Overview
DeepLabHead implements the DeepLabV3 segmentation head architecture, including the Atrous Spatial Pyramid Pooling (ASPP) module with multiple dilation rates for multi-scale feature extraction. The module comprises four key classes: DeepLabHead, ASPPConv, ASPPPooling, and ASPP, all built on torch.nn.
Description
This module provides the segmentation head for DeepLabV3, a semantic segmentation architecture. Its central component is the Atrous Spatial Pyramid Pooling (ASPP) module, which applies parallel atrous (dilated) convolutions at multiple rates to the same feature map, so each branch captures context at a different spatial scale. Combining these branches lets the model segment objects of varying sizes within a single forward pass.
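To make the multi-scale idea concrete, the sketch below computes how many input pixels a dilated 3x3 kernel spans along one axis, and the padding that keeps the spatial size unchanged at stride 1. The helper names `effective_kernel_size` and `same_padding` are illustrative and do not come from the source file; the rates 12, 24, 36 are used here only as sample values.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Span, in input pixels, covered by one dilated kernel along one axis."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def same_padding(kernel_size: int, dilation: int) -> int:
    """Padding that preserves H and W for stride 1 and an odd kernel size."""
    return dilation * (kernel_size - 1) // 2


# Illustrative helpers, not part of deeplabv3.py
for rate in (12, 24, 36):
    print(rate, effective_kernel_size(3, rate), same_padding(3, rate))
# 12 25 12
# 24 49 24
# 36 73 36
```

A 3x3 kernel keeps nine weights at every rate; only the spacing between the sampled pixels grows, which is why larger rates see wider context at the same parameter cost.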
Architecture Components
- DeepLabHead (lines 6-14): Top-level segmentation head that chains ASPP with a final classification convolution
- ASPPConv (lines 17-24): Single atrous convolution block with batch normalization and ReLU activation
- ASPPPooling (lines 27-39): Global average pooling branch that captures image-level context features
- ASPP (lines 42-70): The full Atrous Spatial Pyramid Pooling module that combines multiple parallel branches
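The two branch types listed above can be sketched as follows. This is a minimal illustration written from the descriptions on this page, not a copy of the source file: the class names `SketchASPPConv` and `SketchASPPPooling` are hypothetical, and the bilinear upsampling mode is an assumption, since the page only says the pooled features are upsampled back to the input size.

```python
import torch
from torch import nn
from torch.nn import functional as F


class SketchASPPConv(nn.Sequential):
    """Dilated 3x3 conv -> BatchNorm -> ReLU (hypothetical stand-in for ASPPConv)."""

    def __init__(self, in_channels: int, out_channels: int, dilation: int):
        super().__init__(
            # padding == dilation keeps H and W unchanged for a 3x3 kernel
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class SketchASPPPooling(nn.Sequential):
    """Global pool -> 1x1 conv -> BN -> ReLU -> upsample (hypothetical stand-in)."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        size = x.shape[-2:]            # remember the input H, W
        x = super().forward(x)         # (B, out_channels, 1, 1)
        # Assumption: bilinear upsampling back to the input resolution
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


# Small channel counts keep the demo fast; eval() avoids BatchNorm batch-size issues
conv = SketchASPPConv(8, 4, dilation=6).eval()
pool = SketchASPPPooling(8, 4).eval()
x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    conv_out = conv(x)
    pool_out = pool(x)
print(conv_out.shape, pool_out.shape)
```

Both branches return a tensor with the input's height and width, which is what allows their outputs to be concatenated along the channel dimension.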
Code Reference
Source Location
| File | Lines | Repository |
|---|---|---|
| examples/image_segmenter/deeplabv3/deeplabv3.py | L1-70 | pytorch/serve |
Key Classes
class DeepLabHead(nn.Sequential):
    """
    DeepLabV3 segmentation head.
    Lines 6-14.
    Chains ASPP module with a 3x3 Conv -> BN -> ReLU -> 1x1 classifier.
    """
    def __init__(self, in_channels, num_classes, atrous_rates):
        ...

class ASPPConv(nn.Sequential):
    """
    Single atrous convolution block.
    Lines 17-24.
    Applies a dilated 3x3 convolution followed by BatchNorm2d and ReLU.
    Parameters:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        dilation (int): Dilation rate for the atrous convolution.
    """
    def __init__(self, in_channels, out_channels, dilation):
        ...

class ASPPPooling(nn.Sequential):
    """
    Global average pooling branch for ASPP.
    Lines 27-39.
    Applies adaptive average pooling to 1x1, followed by
    1x1 Conv -> BN -> ReLU, then upsamples back to input spatial size.
    """
    def __init__(self, in_channels, out_channels):
        ...
    def forward(self, x):
        ...

class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling module.
    Lines 42-70.
    Combines:
    - 1x1 convolution branch (no dilation)
    - Multiple 3x3 atrous convolution branches at different dilation rates
    - Global average pooling branch (ASPPPooling)
    - Concatenation + projection via 1x1 conv + dropout
    """
    def __init__(self, in_channels, atrous_rates):
        ...
    def forward(self, x):
        ...
Import
I/O Contract
| Class | Input | Output | Notes |
|---|---|---|---|
| DeepLabHead | Feature map tensor (B, in_channels, H, W) | Segmentation logits (B, num_classes, H, W) | Full head: ASPP + classifier |
| ASPPConv | Feature map tensor (B, C_in, H, W) | Feature map tensor (B, C_out, H, W) | Dilated 3x3 conv + BN + ReLU |
| ASPPPooling | Feature map tensor (B, C_in, H, W) | Feature map tensor (B, C_out, H, W) | Global pooling + upsample |
| ASPP | Feature map tensor (B, C_in, H, W) | Feature map tensor (B, 256, H, W) | Multi-scale feature aggregation |
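The DeepLabHead row of the contract can be illustrated with the classifier stage alone, i.e. the 3x3 Conv -> BN -> ReLU -> 1x1 classifier that follows ASPP. The sketch below is an assumption-laden illustration: `make_classifier_tail` is a hypothetical helper, the 256-channel width is taken from the ASPP output shape listed above, and the final `argmax` is just one common way to turn logits into a per-pixel label map.

```python
import torch
from torch import nn


def make_classifier_tail(num_classes: int, channels: int = 256) -> nn.Sequential:
    """Hypothetical helper: the layers that follow ASPP inside the head."""
    return nn.Sequential(
        nn.Conv2d(channels, channels, 3, padding=1, bias=False),  # keeps H, W
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, num_classes, 1),  # per-pixel class scores
    )


tail = make_classifier_tail(num_classes=21).eval()
aspp_features = torch.randn(2, 256, 32, 32)   # stand-in for an ASPP output
with torch.no_grad():
    logits = tail(aspp_features)               # (B, num_classes, H, W)
    labels = logits.argmax(dim=1)              # (B, H, W) class index per pixel
print(logits.shape)  # torch.Size([2, 21, 32, 32])
print(labels.shape)  # torch.Size([2, 32, 32])
```

Note that the logits stay at the feature-map resolution (H, W); nothing in this contract upsamples them to the original image size.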
ASPP Dilation Rates
The ASPP module processes input through parallel branches with different dilation rates to capture multi-scale context:
| Branch | Type | Dilation | Receptive Field |
|---|---|---|---|
| 1 | 1x1 Convolution | 1 (no dilation) | Local features |
| 2 | ASPPConv | Rate from atrous_rates[0] | Small-scale context |
| 3 | ASPPConv | Rate from atrous_rates[1] | Medium-scale context |
| 4 | ASPPConv | Rate from atrous_rates[2] | Large-scale context |
| 5 | ASPPPooling | Global | Image-level context |
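The branch layout in the table can be sketched end to end as below. This is a compact illustration under stated assumptions rather than the source implementation: `conv_bn_relu` and `SketchASPP` are hypothetical names, the dropout probability of 0.5 and the bilinear upsampling mode are assumptions, and the default width of 256 is taken from the I/O contract. With three rates there are five branches, so the projection sees 5 x `out_channels` input channels.

```python
import torch
from torch import nn
from torch.nn import functional as F


def conv_bn_relu(in_ch: int, out_ch: int, kernel_size: int, dilation: int = 1) -> nn.Sequential:
    """Hypothetical helper: Conv2d -> BatchNorm2d -> ReLU with size-preserving padding."""
    padding = dilation * (kernel_size - 1) // 2
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding,
                  dilation=dilation, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


class SketchASPP(nn.Module):
    """Parallel branches -> channel concat -> 1x1 projection -> dropout."""

    def __init__(self, in_channels, atrous_rates, out_channels=256, dropout=0.5):
        super().__init__()
        # Branch 1 (1x1 conv) plus one dilated 3x3 branch per rate
        self.local_branches = nn.ModuleList(
            [conv_bn_relu(in_channels, out_channels, 1)]
            + [conv_bn_relu(in_channels, out_channels, 3, dilation=r)
               for r in atrous_rates]
        )
        # Final branch: image-level context from global average pooling
        self.pool_branch = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            conv_bn_relu(in_channels, out_channels, 1),
        )
        num_branches = len(atrous_rates) + 2
        self.project = nn.Sequential(
            conv_bn_relu(num_branches * out_channels, out_channels, 1),
            nn.Dropout(dropout),  # assumed rate; the page only says "dropout"
        )

    def forward(self, x):
        size = x.shape[-2:]
        outputs = [branch(x) for branch in self.local_branches]
        pooled = self.pool_branch(x)  # (B, out_channels, 1, 1)
        outputs.append(
            F.interpolate(pooled, size=size, mode="bilinear", align_corners=False)
        )
        return self.project(torch.cat(outputs, dim=1))


# Small sizes keep the demo fast; eval() disables dropout and BatchNorm updates
aspp = SketchASPP(in_channels=16, atrous_rates=[2, 4, 6], out_channels=8).eval()
with torch.no_grad():
    out = aspp(torch.randn(2, 16, 12, 12))
print(out.shape)  # torch.Size([2, 8, 12, 12])
```

Every branch preserves the spatial size, so concatenation along `dim=1` is valid without any cropping or resizing of the convolutional branches.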
Usage Examples
Example 1: Constructing DeepLabHead
# Assumes deeplabv3.py (see Source Location) is on the import path
from deeplabv3 import DeepLabHead

# Create a DeepLabHead with typical ResNet backbone output channels
head = DeepLabHead(
    in_channels=2048,         # ResNet-101 final layer channels
    num_classes=21,           # PASCAL VOC classes (20 objects + background)
    atrous_rates=[12, 24, 36]
)
Example 2: ASPP Forward Pass
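import torch
# Assumes deeplabv3.py (see Source Location) is on the import path
from deeplabv3 import ASPP

# Create ASPP module
aspp = ASPP(in_channels=2048, atrous_rates=[6, 12, 18])
# Switch to inference mode: in training mode, the BatchNorm2d in the pooling
# branch receives one 1x1 value per channel for a batch of 1 and raises an error
aspp.eval()
# Simulated backbone feature map
feature_map = torch.randn(1, 2048, 32, 32)
# Multi-scale feature extraction
with torch.no_grad():
    output = aspp(feature_map)
# output shape: (1, 256, 32, 32)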