Overview
DiffusionFastHandler is a TorchServe handler for serving diffusion models with high throughput. It extends BaseHandler and applies several performance optimizations during initialization: torch.compile(), fused attention projections, dynamic quantization, and BF16 precision. Every lifecycle method is wrapped in the @timed decorator for performance monitoring.
Description
The DiffusionFastHandler class is an optimized diffusion model handler that applies a suite of compile-time and runtime optimizations to maximize inference throughput. During initialization, the handler loads a diffusion pipeline and progressively applies optimizations: torch.compile() for graph optimization, fused attention projections for memory efficiency, dynamic quantization for reduced precision arithmetic, and BF16 (bfloat16) casting for Ampere+ GPU acceleration.
Key Responsibilities
- Pipeline Loading: Loads the diffusion model pipeline from the model archive
- Compilation Optimization: Applies torch.compile() for JIT graph optimization
- Fused Projections: Enables fused QKV attention projections to reduce memory bandwidth
- Quantization: Applies dynamic quantization for reduced precision computation
- BF16 Precision: Casts model weights to BF16 for Ampere+ tensor core acceleration
- Performance Monitoring: Uses @timed decorators on all handler methods
Optimization Stack
| Optimization | Stage | Benefit |
|---|---|---|
| torch.compile() | Initialization | Graph-level operator fusion and optimization |
| Fused Projections | Initialization | Reduced memory bandwidth for attention layers |
| Quantization | Initialization | Lower precision arithmetic for faster compute |
| BF16 Casting | Initialization | Tensor core acceleration on Ampere+ GPUs |
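The BF16 row depends on the GPU generation. The check can be sketched as below, assuming "Ampere+" means CUDA compute capability 8.0 or newer. The function names and the float32 fallback are assumptions made for illustration and are not taken from the handler source.

```python
def supports_bf16_tensor_cores(capability):
    """Return True for CUDA compute capability 8.0 or newer (Ampere and later).

    `capability` is a (major, minor) tuple. With PyTorch installed it can be
    read with torch.cuda.get_device_capability(), which returns such a tuple.
    """
    major, _minor = capability
    return major >= 8


def choose_dtype_name(capability):
    """Pick a dtype name for the pipeline weights.

    The float32 fallback is an assumption of this sketch; the real handler
    may choose a different fallback or none at all.
    """
    return "bfloat16" if supports_bf16_tensor_cores(capability) else "float32"


if __name__ == "__main__":
    # (8, 0) is an Ampere-class capability; (7, 5) is pre-Ampere.
    for cap in [(7, 5), (8, 0), (9, 0)]:
        print(cap, choose_dtype_name(cap))
```

Keeping the capability check in a pure function makes the decision testable on machines without a GPU.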
Code Reference
Source Location
| File | Lines | Repository |
|---|---|---|
| examples/large_models/diffusion_fast/diffusion_fast_handler.py | L1-133 | pytorch/serve |
Key Class
class DiffusionFastHandler(BaseHandler):
    """
    Optimized diffusion model handler with compilation and quantization.
    Lines 15-134.
    All methods decorated with @timed for performance monitoring.
    """

    def initialize(self, context):
        """
        Load pipeline and apply performance optimizations.

        Applies the following optimization stack in order:
        1. Load diffusion pipeline from model archive
        2. Apply torch.compile() for graph optimization
        3. Enable fused attention projections
        4. Apply dynamic quantization
        5. Cast to BF16 precision

        Parameters:
            context: TorchServe context with system_properties and manifest.
        """
        ...

    def preprocess(self, data):
        """
        Extract text prompts from request data.

        Parses incoming requests and extracts text prompts.
        Uses batch_size=1 for optimal pipeline performance.

        Parameters:
            data (list): List of request input dicts.

        Returns:
            list: Text prompt strings (batch_size=1).
        """
        ...

    def inference(self, data, *args, **kwargs):
        """
        Run the optimized diffusion pipeline.

        Executes the compiled and quantized pipeline
        for image generation from text prompts.

        Parameters:
            data (list): Text prompts from preprocess.

        Returns:
            Pipeline output (generated images).
        """
        ...

    def postprocess(self, data):
        """
        Convert generated images to numpy arrays.

        Parameters:
            data: Pipeline output images.

        Returns:
            list: Numpy array representations of generated images.
        """
        ...
Import
from ts.torch_handler.base_handler import BaseHandler
I/O Contract
| Method | Input | Output | Notes |
|---|---|---|---|
| initialize(context) | Context with system_properties, manifest | None (sets self.pipeline) | Applies full optimization stack |
| preprocess(data) | list of request dicts with text prompts | list of prompt strings | batch_size=1 for pipeline |
| inference(data) | list of prompt strings | Pipeline output (images) | Runs compiled + quantized pipeline |
| postprocess(data) | Pipeline output | list of numpy arrays | Converts for serialization |
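The preprocess contract above can be illustrated with a minimal sketch. It assumes the prompt arrives under a "data" or "body" key, either as text, as bytes, or as an already-parsed JSON object with a "data" field. These key names and the helper `extract_prompt` are assumptions for illustration; the actual handler may parse requests differently.

```python
def extract_prompt(request):
    """Return the text prompt carried by one request dict (illustrative)."""
    # Look under "data" first, then fall back to "body".
    payload = request.get("data")
    if payload is None:
        payload = request.get("body")
    # Assumption: a JSON body such as {"data": "..."} may arrive as a dict.
    if isinstance(payload, dict):
        payload = payload.get("data")
    # Raw request bodies are often delivered as bytes.
    if isinstance(payload, (bytes, bytearray)):
        payload = payload.decode("utf-8")
    if not isinstance(payload, str) or not payload.strip():
        raise ValueError("request does not contain a text prompt")
    return payload.strip()


def preprocess(requests):
    """Map request dicts to prompt strings, one prompt per request."""
    return [extract_prompt(request) for request in requests]


if __name__ == "__main__":
    # With batch_size=1 the list holds a single request.
    print(preprocess([{"body": "A hyperrealistic portrait of an astronaut"}]))
```

Rejecting empty or non-text payloads early keeps malformed requests from reaching the much more expensive inference step.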
Performance Decorators
All handler methods use the @timed decorator, which measures and logs execution time for each stage of the inference pipeline:
| Method | Metric Tracked |
|---|---|
| initialize | Model load + optimization compilation time |
| preprocess | Input parsing and prompt extraction time |
| inference | Diffusion pipeline execution time |
| postprocess | Output conversion time |
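To make the idea of a timing decorator concrete, here is a minimal stand-in written with the standard library only. It is not TorchServe's @timed implementation; the name `simple_timed`, the log format, and the `DemoHandler` class are assumptions for illustration.

```python
import functools
import logging
import time

logger = logging.getLogger(__name__)


def simple_timed(func):
    """Log how long each call to `func` takes, in milliseconds."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Runs whether the call returned or raised, so failures are timed too.
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            logger.info("%s took %.2f ms", func.__qualname__, elapsed_ms)

    return wrapper


class DemoHandler:
    """Hypothetical handler used only to show the decorator in place."""

    @simple_timed
    def preprocess(self, data):
        return [str(item) for item in data]


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    DemoHandler().preprocess(["a prompt"])
```

Using `functools.wraps` preserves the wrapped method's name and docstring, which keeps log lines and introspection readable.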
Usage Examples
Example 1: Handler Initialization with Optimizations
# During initialization, the handler applies a multi-stage optimization stack:
handler = DiffusionFastHandler()
handler.initialize(context)
# After initialization:
# - Pipeline is loaded and compiled with torch.compile()
# - Attention projections are fused for memory efficiency
# - Dynamic quantization is applied
# - Model weights are cast to BF16
Example 2: Optimized Inference
# Request with a text prompt (batch_size=1)
data = [{"body": "A hyperrealistic portrait of an astronaut in a garden"}]
prompts = handler.preprocess(data) # @timed
images = handler.inference(prompts) # @timed - runs compiled pipeline
output = handler.postprocess(images) # @timed
# output contains numpy array of the generated image
Example 3: curl Request
curl -X POST http://localhost:8080/predictions/diffusion_fast \
-H "Content-Type: application/json" \
-d '{"data": "A cyberpunk city with neon lights and flying cars"}'
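The same request can be built from Python with the standard library. This is a sketch that mirrors the curl command above; the helper name `build_prediction_request` and its default arguments are assumptions, and the response format depends on how the handler serializes its output.

```python
import json
import urllib.request


def build_prediction_request(
    prompt,
    base_url="http://localhost:8080",
    model_name="diffusion_fast",
):
    """Build a POST request equivalent to the curl example (illustrative)."""
    payload = json.dumps({"data": prompt}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predictions/{model_name}",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


if __name__ == "__main__":
    request = build_prediction_request(
        "A cyberpunk city with neon lights and flying cars"
    )
    print(request.get_method(), request.full_url)
    # Sending requires a running TorchServe instance with the model registered:
    # with urllib.request.urlopen(request) as response:
    #     raw_output = response.read()
```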