Implementation:OpenGVLab InternVL DDP Communication Hooks

From Leeroopedia


Knowledge Sources
Domains Distributed Training, Gradient Compression, Classification
Last Updated 2026-02-07 14:00 GMT

Overview

Provides custom DDP (DistributedDataParallel) communication hooks that cast gradient tensors to a 16-bit floating-point format (FP16 or BF16) before allreduce, reducing the volume of data exchanged between GPUs.

Description

This module implements five functions for gradient communication in distributed training:

  • allreduce_hook -- Standard gradient averaging across workers via asynchronous allreduce. Divides the gradient tensor by the process group size before allreduce to avoid overflow, especially for FP16.
  • fp16_compress_hook -- Casts the gradient bucket buffer to torch.float16 before allreduce, then decompresses back to the original data type (e.g., float32) via in-place copy to minimize peak memory usage.
  • bf16_compress_hook -- Same as fp16_compress_hook but uses torch.bfloat16. Requires NCCL version later than 2.9.6.
  • fp16_compress_wrapper -- A higher-order function that wraps any existing DDP communication hook with FP16 compression. Casts the bucket buffer to float16 before delegating to the wrapped hook, then decompresses the result.
  • bf16_compress_wrapper -- Same as fp16_compress_wrapper but uses bfloat16.
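The pre-division mentioned for allreduce_hook can be illustrated without a process group. The sketch below emulates the sum an allreduce would compute across two workers with plain tensor addition; the two-worker setup and the gradient value 40000 are illustrative assumptions, not values taken from the module.

```python
import torch

# Hypothetical per-worker gradients, already cast to FP16.
# The largest finite FP16 value is 65504.
world_size = 2
worker_grads = [
    torch.tensor([40000.0], dtype=torch.float16) for _ in range(world_size)
]

# Sum first, divide afterwards: the intermediate sum (80000) overflows to inf.
sum_then_divide = (worker_grads[0] + worker_grads[1]) / world_size

# Divide first, sum afterwards (the order the hook uses): the result stays finite.
divide_then_sum = worker_grads[0] / world_size + worker_grads[1] / world_size

print("sum then divide:", sum_then_divide)
print("divide then sum:", divide_then_sum)
```

Dividing each contribution by the group size before the reduction keeps the partial sums within the same range as the individual gradients.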

All operations are asynchronous via torch.futures.Future, enabling overlap of communication with computation. In-place decompression (using copy_) follows the pattern from PyTorch issue #45968 to reduce peak memory.
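As a rough illustration of the future chaining and in-place decompression described here, the sketch below runs in a single process. The communication step is replaced by a stand-in future that the caller completes by hand; `compress_and_chain` and `comm_future` are hypothetical names for this example, not part of the module.

```python
import torch


def compress_and_chain(buffer: torch.Tensor, comm_future: torch.futures.Future):
    """Return the FP16 payload and a future that writes the result back into `buffer`.

    `comm_future` stands in for the future an asynchronous allreduce would
    return; it is assumed to resolve to a float16 tensor.
    """
    compressed = buffer.to(torch.float16)  # half-size payload to communicate

    def decompress(fut):
        # copy_ casts back to buffer's dtype and reuses its storage,
        # so no second full-precision tensor is allocated.
        buffer.copy_(fut.value())
        return buffer

    return compressed, comm_future.then(decompress)


grad = torch.tensor([1.0, 2.0, 0.5])   # float32 gradient buffer
comm_future = torch.futures.Future()   # stand-in for the allreduce future
compressed, result_future = compress_and_chain(grad, comm_future)

# The callback has not run yet; other work could overlap with communication here.
pending = not result_future.done()

# Pretend the reduced value is twice the local gradient, delivered as FP16.
reduced = (torch.tensor([1.0, 2.0, 0.5]) * 2).to(torch.float16)
comm_future.set_result(reduced)        # completing the future triggers decompress
result = result_future.wait()
```

The result is written into the same storage as the original float32 buffer, which is the point of using copy_ rather than allocating a new tensor with a cast.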

Usage

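Use these hooks when training large models (like InternViT-6B) across multiple GPUs or nodes. When gradients are stored in float32, casting them to a 16-bit format roughly halves the gradient communication volume. Register a hook via ddp_model.register_comm_hook(process_group, hook_fn). Use the wrapper variants to compose compression with other hooks like PowerSGD; in that case the first argument to register_comm_hook is the state object expected by the wrapped hook.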
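The savings follow directly from element sizes, assuming the gradients start out as float32. A small check with a hypothetical one-million-element bucket:

```python
import torch

num_elements = 1_000_000  # hypothetical bucket size, for illustration only


def payload_bytes(dtype: torch.dtype) -> int:
    # element_size() is the number of bytes per element for the dtype
    return num_elements * torch.empty(0, dtype=dtype).element_size()


fp32_bytes = payload_bytes(torch.float32)
fp16_bytes = payload_bytes(torch.float16)
bf16_bytes = payload_bytes(torch.bfloat16)
savings = 1 - fp16_bytes / fp32_bytes

print(fp32_bytes, fp16_bytes, bf16_bytes, savings)
```

FP16 and BF16 have the same element size, so both variants give the same reduction in payload; gradients that are already 16-bit gain nothing from the cast.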

Code Reference

Source Location

Signature

def allreduce_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]: ...

def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]: ...

def bf16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]: ...

def fp16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: ...

def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]: ...

Import

from classification.ddp_hooks import fp16_compress_hook, bf16_compress_hook
from classification.ddp_hooks import fp16_compress_wrapper, bf16_compress_wrapper

I/O Contract

Inputs

Name | Type | Required | Description
process_group | dist.ProcessGroup | Yes | The process group for distributed communication (None defaults to WORLD)
bucket | dist.GradBucket | Yes | The gradient bucket containing the tensor buffer to be communicated

Outputs

Name | Type | Description
future | torch.futures.Future[torch.Tensor] | A future that resolves to the allreduced (and decompressed) gradient tensor

Usage Examples

Basic Usage

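import torch.distributed as dist
from classification.ddp_hooks import fp16_compress_hook

# Assumes ddp_model is an existing DistributedDataParallel instance and the
# default process group has already been initialized.
process_group = dist.group.WORLD

# Register FP16 gradient compression on the DDP model
ddp_model.register_comm_hook(process_group, fp16_compress_hook)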

Composing with Another Hook

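from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState, powerSGD_hook
from classification.ddp_hooks import bf16_compress_wrapper

# PowerSGD keeps its own state object; process_group=None selects the default group
state = PowerSGDState(process_group=None, matrix_approximation_rank=1)

# Wrap an existing hook (e.g., PowerSGD) with BF16 compression
ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))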
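How a wrapper composes with an arbitrary hook can be sketched with stand-ins. In the sketch below, `FakeBucket` and `identity_hook` are hypothetical substitutes for dist.GradBucket and a real communication hook, and `bf16_wrapper_sketch` is an illustrative reimplementation rather than the module's code. It relies only on the buffer() and set_buffer() bucket methods.

```python
import torch


class FakeBucket:
    """Hypothetical stand-in exposing the two GradBucket methods used below."""

    def __init__(self, tensor):
        self._tensor = tensor

    def buffer(self):
        return self._tensor

    def set_buffer(self, tensor):
        self._tensor = tensor


def identity_hook(state, bucket):
    """Stand-in inner hook: resolves immediately to the bucket's current buffer."""
    state["seen_dtype"] = bucket.buffer().dtype  # record what the inner hook received
    fut = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut


def bf16_wrapper_sketch(hook):
    def wrapped(state, bucket):
        original = bucket.buffer()
        bucket.set_buffer(original.to(torch.bfloat16))  # compress before delegating
        fut = hook(state, bucket)

        def decompress(fut):
            original.copy_(fut.value())  # cast back into the original buffer
            return original

        return fut.then(decompress)

    return wrapped


state = {}
grad = torch.tensor([1.0, -2.0, 0.25])
out = bf16_wrapper_sketch(identity_hook)(state, FakeBucket(grad)).wait()
```

The inner hook only ever sees the BF16 buffer, while the caller gets back a tensor in the original dtype. This is why the wrapper can be placed around any hook that follows the (state, bucket) -> Future contract.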
