Implementation:Pytorch Serve DeepLabHead

From Leeroopedia

Overview

DeepLabHead implements the DeepLabV3 segmentation head architecture, including the Atrous Spatial Pyramid Pooling (ASPP) module with multiple dilation rates for multi-scale feature extraction. The module comprises four key classes: DeepLabHead, ASPPConv, ASPPPooling, and ASPP, all built on torch.nn.

Field | Value
Page Type | Implementation
Implementation Type | API Doc
Domains | Semantic_Segmentation, Computer_Vision
Knowledge Sources | Pytorch_Serve
Workflow | Model_Deployment
Last Updated | 2026-02-13 18:52 GMT

Description

This module provides the segmentation head for DeepLabV3, a semantic segmentation architecture. Its central component is the Atrous Spatial Pyramid Pooling (ASPP) module, which applies parallel atrous (dilated) convolutions at several dilation rates, so each branch sees a different amount of spatial context at the same output resolution. Combining the branches lets the model segment objects of varying sizes in a single forward pass.
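To make the "same output resolution" property concrete, here is a minimal sketch showing that a 3x3 convolution with `padding` set equal to `dilation` keeps the spatial size unchanged at every rate. The channel counts and the rates in the loop are illustrative assumptions, not values read from the source file.

```python
import torch
import torch.nn as nn

# Small stand-in feature map: (batch, channels, height, width)
x = torch.randn(1, 8, 64, 64)

output_shapes = {}
for rate in (1, 6, 12, 18):  # illustrative rates, not taken from deeplabv3.py
    # padding == dilation cancels the border shrinkage of a dilated 3x3 kernel
    conv = nn.Conv2d(8, 4, kernel_size=3, padding=rate, dilation=rate, bias=False)
    output_shapes[rate] = tuple(conv(x).shape)

print(output_shapes)
```

Because every branch returns the same height and width, the branch outputs can be concatenated along the channel dimension without any resizing.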

Architecture Components

  • DeepLabHead (lines 6-14): Top-level segmentation head that chains ASPP with a 3x3 Conv -> BN -> ReLU block and a final 1x1 classification convolution
  • ASPPConv (lines 17-24): Single atrous convolution block with batch normalization and ReLU activation
  • ASPPPooling (lines 27-39): Global average pooling branch that captures image-level context features
  • ASPP (lines 42-70): The full Atrous Spatial Pyramid Pooling module that combines multiple parallel branches

Code Reference

Source Location

File | Lines | Repository
examples/image_segmenter/deeplabv3/deeplabv3.py | L1-70 | pytorch/serve

Key Classes

class DeepLabHead(nn.Sequential):
    """
    DeepLabV3 segmentation head.
    Lines 6-14.

    Chains ASPP module with a 3x3 Conv -> BN -> ReLU -> 1x1 classifier.
    """

    def __init__(self, in_channels, num_classes, atrous_rates):
        ...


class ASPPConv(nn.Sequential):
    """
    Single atrous convolution block.
    Lines 17-24.

    Applies a dilated 3x3 convolution followed by BatchNorm2d and ReLU.

    Parameters:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        dilation (int): Dilation rate for the atrous convolution.
    """

    def __init__(self, in_channels, out_channels, dilation):
        ...


class ASPPPooling(nn.Sequential):
    """
    Global average pooling branch for ASPP.
    Lines 27-39.

    Applies adaptive average pooling to 1x1, followed by
    1x1 Conv -> BN -> ReLU, then upsamples back to input spatial size.
    """

    def __init__(self, in_channels, out_channels):
        ...

    def forward(self, x):
        ...


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling module.
    Lines 42-70.

    Combines:
    - 1x1 convolution branch (no dilation)
    - Multiple 3x3 atrous convolution branches at different dilation rates
    - Global average pooling branch (ASPPPooling)
    - Concatenation + projection via 1x1 conv + dropout
    """

    def __init__(self, in_channels, atrous_rates):
        ...

    def forward(self, x):
        ...
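The elided bodies above could be filled in roughly as follows. This is a hedged sketch that follows the docstrings on this page and the common torchvision-style layout; it is not copied from deeplabv3.py. The attribute names `convs` and `project`, the dropout probability of 0.5, and the bilinear upsampling mode are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            # padding == dilation keeps H and W unchanged for a 3x3 kernel
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),  # one value per channel
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        size = x.shape[-2:]
        x = super().forward(x)
        # Broadcast the image-level feature back to the input resolution
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super().__init__()
        out_channels = 256  # matches the output width in the I/O contract
        branches = [nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )]
        branches += [ASPPConv(in_channels, out_channels, r) for r in atrous_rates]
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)
        self.project = nn.Sequential(
            nn.Conv2d(len(branches) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),  # assumed probability
        )

    def forward(self, x):
        # Run all branches on the same input, then concatenate along channels
        return self.project(torch.cat([conv(x) for conv in self.convs], dim=1))


class DeepLabHead(nn.Sequential):
    def __init__(self, in_channels, num_classes, atrous_rates):
        super().__init__(
            ASPP(in_channels, atrous_rates),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1),  # per-pixel class logits
        )


# Quick shape check with small, illustrative sizes
head = DeepLabHead(in_channels=16, num_classes=3, atrous_rates=[2, 4, 6]).eval()
with torch.no_grad():
    logits = head(torch.randn(1, 16, 20, 20))
print(logits.shape)
```

The check runs in eval mode because the pooling branch feeds a 1x1 map into BatchNorm2d, which cannot compute batch statistics from a single image in training mode.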

Import

import torch.nn as nn

I/O Contract

Class | Input | Output | Notes
DeepLabHead | Feature map tensor (B, in_channels, H, W) | Segmentation logits (B, num_classes, H, W) | Full head: ASPP + classifier
ASPPConv | Feature map tensor (B, C_in, H, W) | Feature map tensor (B, C_out, H, W) | Dilated 3x3 conv + BN + ReLU
ASPPPooling | Feature map tensor (B, C_in, H, W) | Feature map tensor (B, C_out, H, W) | Global pooling + upsample
ASPP | Feature map tensor (B, C_in, H, W) | Feature map tensor (B, 256, H, W) | Multi-scale feature aggregation
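The logits returned by the head are at feature-map resolution, so a caller typically resizes them and takes a per-pixel argmax to obtain a class mask. The sketch below uses a random tensor as a stand-in for the DeepLabHead output; the 256x256 target size and the bilinear mode are assumptions for illustration, not behavior documented for this file.

```python
import torch
import torch.nn.functional as F

# Stand-in for DeepLabHead output: (B, num_classes, H, W)
logits = torch.randn(2, 21, 32, 32)

# Resize logits to the (assumed) input image resolution
upsampled = F.interpolate(logits, size=(256, 256), mode="bilinear",
                          align_corners=False)

# Per-pixel class index: (B, 256, 256), values in [0, num_classes)
mask = upsampled.argmax(dim=1)
print(mask.shape, mask.dtype)
```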

ASPP Dilation Rates

The ASPP module processes input through parallel branches with different dilation rates to capture multi-scale context:

Branch | Type | Dilation | Context Captured
1 | 1x1 Convolution | 1 (no dilation) | Local features
2 | ASPPConv | Rate from atrous_rates[0] | Small-scale context
3 | ASPPConv | Rate from atrous_rates[1] | Medium-scale context
4 | ASPPConv | Rate from atrous_rates[2] | Large-scale context
5 | ASPPPooling | None (global average pooling) | Image-level context
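As a worked example of the table, a k x k kernel with dilation d spans k + (k - 1)(d - 1) input positions per side, and the concatenated tensor has one group of channels per branch. The helper below is hypothetical (it is not part of deeplabv3.py), and it assumes each branch emits 256 channels, in line with the ASPP output width listed in the I/O contract.

```python
def aspp_branch_summary(atrous_rates, kernel_size=3, branch_channels=256):
    """Return (rows, concat_channels) for an ASPP configuration.

    Each row is (branch type, dilation rate, kernel extent in input pixels).
    """
    rows = [("1x1 conv", 1, 1)]
    for rate in atrous_rates:
        # Extent covered by a dilated kernel: k + (k - 1) * (d - 1)
        extent = kernel_size + (kernel_size - 1) * (rate - 1)
        rows.append(("3x3 atrous conv", rate, extent))
    rows.append(("global pooling", None, None))  # covers the whole feature map
    return rows, len(rows) * branch_channels


rows, concat_channels = aspp_branch_summary([12, 24, 36])
for name, rate, extent in rows:
    print(name, rate, extent)
print("channels before projection:", concat_channels)
```

With rates 12, 24, and 36 the three atrous branches span 25, 49, and 73 positions per side, and the five branches yield 1280 channels that the final 1x1 projection reduces to 256.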

Usage Examples

Example 1: Constructing DeepLabHead

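# Assumes examples/image_segmenter/deeplabv3 is on sys.path so that
# deeplabv3.py can be imported as a module.
from deeplabv3 import DeepLabHead

# Create a DeepLabHead with typical ResNet backbone output channels
head = DeepLabHead(
    in_channels=2048,       # ResNet-101 final layer channels
    num_classes=21,         # PASCAL VOC: 20 object classes + background
    atrous_rates=[12, 24, 36]
)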

Example 2: ASPP Forward Pass

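import torch

# Assumes examples/image_segmenter/deeplabv3 is on sys.path.
from deeplabv3 import ASPP

# Create ASPP module. eval() makes BatchNorm use its running statistics,
# which is needed here because the batch holds a single image.
aspp = ASPP(in_channels=2048, atrous_rates=[6, 12, 18])
aspp.eval()

# Simulated backbone feature map
feature_map = torch.randn(1, 2048, 32, 32)

# Multi-scale feature extraction (no gradients needed for inference)
with torch.no_grad():
    output = aspp(feature_map)
# output shape: (1, 256, 32, 32)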
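One practical caveat with a single-image batch like the one above: the global pooling branch reduces each channel to one value, and BatchNorm2d cannot compute batch statistics from a single value in training mode. The sketch below reproduces this with a small stand-in for the pooling branch; the channel counts are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Stand-in for the pooling branch: pool to 1x1, then 1x1 Conv -> BN -> ReLU
pool_branch = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Conv2d(16, 8, 1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)
x = torch.randn(1, 16, 32, 32)  # batch of one image

# Training mode: BatchNorm sees one value per channel and raises
pool_branch.train()
try:
    pool_branch(x)
    train_error = None
except ValueError as exc:
    train_error = str(exc)
print("train mode error:", train_error)

# Eval mode: running statistics are used, so a single image works
pool_branch.eval()
with torch.no_grad():
    pooled = pool_branch(x)
print(pooled.shape)
```

For inference, calling `eval()` avoids the error; for training, a batch size of at least 2 does.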
