Implementation:OpenGVLab InternVL DDP Communication Hooks
| Knowledge Sources | |
|---|---|
| Domains | Distributed Training, Gradient Compression, Classification |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Provides custom DistributedDataParallel (DDP) communication hooks that, in most variants, compress gradient tensors to half precision (FP16 or BF16) during allreduce, reducing the bandwidth spent on inter-GPU gradient communication.
Description
This module implements five functions for gradient communication in distributed training:
- allreduce_hook -- Standard gradient averaging across workers via asynchronous allreduce. Divides the gradient tensor by the process group size before allreduce to avoid overflow, especially for FP16.
- fp16_compress_hook -- Casts the gradient bucket buffer to torch.float16 before allreduce, then decompresses back to the original data type (e.g., float32) via in-place copy to minimize peak memory usage.
- bf16_compress_hook -- Same as fp16_compress_hook but casts to torch.bfloat16, which keeps the FP32 exponent range at the cost of fewer mantissa bits. Requires an NCCL version later than 2.9.6.
- fp16_compress_wrapper -- A higher-order function that wraps any existing DDP communication hook with FP16 compression. Casts the bucket buffer to float16 before delegating to the wrapped hook, then decompresses the result.
- bf16_compress_wrapper -- Same as fp16_compress_wrapper but uses bfloat16.
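The numeric trade-offs behind these hooks can be checked without a GPU. The sketch below is a pure-Python simulation, not the module's code: the helpers to_fp16 and to_bf16 are hypothetical names that emulate the casts with the standard struct module, and a sequential sum stands in for the reduction an allreduce performs (real NCCL reductions use a different order). It assumes eight ranks that each hold the same gradient value, and shows why allreduce_hook divides before reducing and how BF16 trades FP16's precision for FP32's exponent range.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision and back; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the FP16 range; a GPU cast produces inf instead
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round x (via float32) to bfloat16 with round-to-nearest-even, then widen back."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    rounded = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16  # keep the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", rounded & 0xFFFFFFFF))[0]


WORLD_SIZE = 8
local_grads = [8192.0] * WORLD_SIZE  # hypothetical: the same gradient element on 8 ranks

# Sum first, divide afterwards: the FP16 partial sum reaches 65536 > 65504 and overflows
acc = 0.0
for g in local_grads:
    acc = to_fp16(acc + to_fp16(g))
sum_then_divide = acc / WORLD_SIZE

# Divide first (as allreduce_hook does), then sum: every partial sum stays in range
acc = 0.0
for g in local_grads:
    acc = to_fp16(acc + to_fp16(g / WORLD_SIZE))
divide_then_sum = acc

print(sum_then_divide, divide_then_sum)    # inf 8192.0
print(to_fp16(65536.0), to_bf16(65536.0))  # inf 65536.0  (BF16 shares FP32's range)
print(to_fp16(0.1), to_bf16(0.1))          # 0.0999755859375 0.10009765625  (BF16 is coarser)
```

FP16 keeps 10 mantissa bits but tops out at 65504; BF16 keeps only 7 mantissa bits but cannot overflow where FP32 would not, which is why BF16 compression is less sensitive to the order of division and summation.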
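Each hook returns a torch.futures.Future (and each wrapper returns a hook that does the same), so gradient communication can overlap with the rest of the backward pass. Decompression copies the reduced values back into the original buffer in place with copy_, following the pattern from PyTorch issue #45968, to reduce peak memory.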
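The asynchronous, in-place pattern can be illustrated with the standard library alone. This is a hypothetical simulation, not the module's code: concurrent.futures.Future stands in for torch.futures.Future, the helper then mimics Future.then(), fp16_compress_sketch and fake_all_reduce are illustrative names, and struct's half-precision format emulates the FP16 cast. It shows that the hook returns before communication finishes and that the decompressed result lands in the caller's original buffer rather than a new allocation.

```python
import struct
from array import array
from concurrent.futures import Future
from typing import Callable, List


def then(fut: Future, fn: Callable[[Future], object]) -> Future:
    """Mimic torch.futures.Future.then(): run fn(fut) once fut completes, expose its result."""
    chained: Future = Future()

    def _on_done(done: Future) -> None:
        try:
            chained.set_result(fn(done))
        except Exception as exc:  # surface callback errors on the chained future
            chained.set_exception(exc)

    fut.add_done_callback(_on_done)
    return chained


def fp16_compress_sketch(buffer: array, world_size: int,
                         all_reduce: Callable[[List[float]], Future]) -> Future:
    """Hypothetical stand-in for fp16_compress_hook operating on a float32 array."""
    # Compress: a separate half-precision copy, pre-divided by the group size
    compressed = [struct.unpack("<e", struct.pack("<e", v / world_size))[0] for v in buffer]
    comm_future = all_reduce(compressed)  # returns immediately; communication is still pending

    def decompress(fut: Future) -> array:
        # Copy the reduced values back into the original buffer instead of allocating a new one
        for i, value in enumerate(fut.result()):
            buffer[i] = value
        return buffer

    return then(comm_future, decompress)


# A fake single-rank allreduce whose completion is triggered by hand below
pending: List[tuple] = []


def fake_all_reduce(values: List[float]) -> Future:
    fut: Future = Future()
    pending.append((fut, values))
    return fut


grads = array("f", [0.1, 2.0, 3.5])
result = fp16_compress_sketch(grads, world_size=1, all_reduce=fake_all_reduce)
print(result.done())  # False: backward computation could continue here
fut, values = pending.pop()
fut.set_result(values)  # with one rank, the allreduced sum equals the input
print(result.result() is grads, list(grads))  # True [0.0999755859375, 2.0, 3.5]
```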
Usage
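Use these hooks when training large models such as InternViT-6B across multiple GPUs or nodes: sending FP16 or BF16 instead of FP32 gradients halves the gradient payload (2 bytes per element instead of 4). Register a hook with ddp_model.register_comm_hook(state, hook); for the hooks in this module, state is the process group (None selects the default WORLD group). Use the wrapper variants to add compression on top of another hook such as PowerSGD, in which case state is that hook's own state object.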
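A back-of-the-envelope check of the halving: the sketch below assumes a ring allreduce, in which each rank sends about 2(p - 1)/p times the payload (NCCL may pick other algorithms, such as trees), a round 6 billion FP32 gradient elements, 8 GPUs, and one full gradient synchronization per step. The helper name ring_allreduce_bytes_per_rank is hypothetical.

```python
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def ring_allreduce_bytes_per_rank(num_elements: int, dtype: str, world_size: int) -> float:
    """Bytes each rank sends in a ring allreduce (reduce-scatter followed by all-gather)."""
    payload = num_elements * BYTES_PER_ELEMENT[dtype]
    return 2 * (world_size - 1) / world_size * payload


NUM_GRAD_ELEMENTS = 6_000_000_000  # assumption: a round 6B parameters, all trainable
fp32_bytes = ring_allreduce_bytes_per_rank(NUM_GRAD_ELEMENTS, "float32", world_size=8)
fp16_bytes = ring_allreduce_bytes_per_rank(NUM_GRAD_ELEMENTS, "float16", world_size=8)
print(f"{fp32_bytes / 1e9:.0f} GB vs {fp16_bytes / 1e9:.0f} GB sent per rank per step")  # 42 GB vs 21 GB
print(fp16_bytes / fp32_bytes)  # 0.5
```

Only the bandwidth term halves; per-message latency and the cost of the casts are unchanged, so the end-to-end gain depends on how communication-bound the training run is.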
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: classification/ddp_hooks.py
- Lines: 1-182
Signature
```python
from typing import Any, Callable

import torch
import torch.distributed as dist


def allreduce_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]: ...

def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]: ...

def bf16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]: ...

def fp16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: ...

def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: ...
```
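The wrapper signatures describe a higher-order function: a hook goes in and a hook with the same call signature comes out. The pure-Python sketch below illustrates only that shape and is not the module's code. Plain lists stand in for GradBucket buffers, concurrent.futures.Future stands in for torch.futures.Future, and compress_wrapper_sketch, fp16_round, and identity_hook are hypothetical names.

```python
import struct
from concurrent.futures import Future
from typing import Any, Callable, List

# Stand-in for Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
Hook = Callable[[Any, List[float]], Future]


def fp16_round(x: float) -> float:
    """Emulate a float -> float16 -> float cast with struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def compress_wrapper_sketch(hook: Hook) -> Hook:
    """Hypothetical analogue of fp16_compress_wrapper: same signature in, same signature out."""

    def wrapped(state: Any, bucket: List[float]) -> Future:
        compressed = [fp16_round(v) for v in bucket]  # cast before delegating
        inner = hook(state, compressed)  # the wrapped hook only sees half-precision values
        outer: Future = Future()

        def decompress(done: Future) -> None:
            bucket[:] = done.result()  # copy back into the caller's buffer in place
            outer.set_result(bucket)

        inner.add_done_callback(decompress)
        return outer

    return wrapped


# A trivial inner hook (hypothetical) that records its input and "reduces" over one rank
seen: List[List[float]] = []


def identity_hook(state: Any, bucket: List[float]) -> Future:
    seen.append(list(bucket))
    fut: Future = Future()
    fut.set_result(bucket)
    return fut


grads = [0.1, 1.0]
out = compress_wrapper_sketch(identity_hook)(None, grads).result()
print(seen[0], out is grads)  # [0.0999755859375, 1.0] True
```

Because the wrapped hook keeps the (state, bucket) signature, it can be registered anywhere a hook is expected; averaging stays the responsibility of the inner hook, which is what lets compression be layered over PowerSGD.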
Import
```python
from classification.ddp_hooks import fp16_compress_hook, bf16_compress_hook
from classification.ddp_hooks import fp16_compress_wrapper, bf16_compress_wrapper
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
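| process_group | dist.ProcessGroup or None | Yes | The process group for the allreduce; None falls back to the default WORLD group. For hooks returned by the wrappers, this position holds the wrapped hook's state (e.g., a PowerSGD state) |
| bucket | dist.GradBucket | Yes | The gradient bucket whose flattened tensor buffer is communicated |
| hook | Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] | Wrappers only | The communication hook to wrap with FP16 or BF16 compression |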
Outputs
| Name | Type | Description |
|---|---|---|
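| future | torch.futures.Future[torch.Tensor] | Returned by the hooks: a future that resolves to the allreduced (and, for the compressing hooks, decompressed) gradient tensor |
| hook | Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]] | Returned by the wrappers: a new hook with the same signature as the one passed in |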
Usage Examples
Basic Usage
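```python
import torch.distributed as dist
from classification.ddp_hooks import fp16_compress_hook
# Assumes dist.init_process_group(...) has run and ddp_model is a DistributedDataParallel instance
process_group = dist.group.WORLD  # or None, or a subgroup created with dist.new_group(...)
# Register FP16 gradient compression on the DDP model
ddp_model.register_comm_hook(process_group, fp16_compress_hook)
```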
Composing with Another Hook
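```python
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
from classification.ddp_hooks import bf16_compress_wrapper
# PowerSGD keeps its own state object, which takes the place of the process group argument
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
# Wrap PyTorch's PowerSGD hook with BF16 compression
ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD.powerSGD_hook))
```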