Implementation: Allenai Open Instruct WorkerWrap
| Type | Class (vLLM Worker Extension) |
|---|---|
| Source | open_instruct/vllm_utils_workerwrap.py:L4-87 |
| Dependencies | vllm, torch, torch.distributed, ray.util.collective (optional) |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete vLLM worker extension for receiving model weight updates from DeepSpeed learners via NCCL broadcast or CUDA IPC, provided by the Open Instruct library.
Description
WorkerWrap is a class that is injected into vLLM's worker processes via the worker_extension_cls parameter. It adds three capabilities that vLLM workers do not have natively:
- Process group initialization: The init_process_group() method creates a torch distributed process group that links the vLLM worker to the training rank 0 process. This enables NCCL collective operations (specifically broadcast) between the learner and inference workers.
- Weight update: The update_weight() method receives a parameter name, dtype, and shape, allocates an empty CUDA tensor, receives the parameter data via NCCL broadcast from rank 0, and loads it into the vLLM model runner via model.load_weights().
- CUDA IPC weight update: The update_weight_cuda_ipc() method provides an alternative path for same-node weight transfer using CUDA IPC handles, which avoids the overhead of NCCL for co-located GPUs.
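The rank arithmetic behind the combined group can be illustrated with a small sketch (the helper name and numbers are illustrative, not from the actual source): the learner holds rank 0, and each vLLM worker lands at its local torch rank plus rank_offset.

```python
def combined_rank(local_rank: int, rank_offset: int) -> int:
    # A vLLM worker's rank in the joint weight-sync group is its
    # local torch rank shifted past the learner process(es).
    return local_rank + rank_offset

num_engines, tp_size = 2, 1
world_size = 1 + num_engines * tp_size   # 1 learner + all TP workers
worker_ranks = [combined_rank(r, 1) for r in range(num_engines * tp_size)]
# -> world_size == 3, worker_ranks == [1, 2]
```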
The class runs inside vLLM worker processes (which may be subprocesses spawned by the "mp" distributed executor backend), so all imports are deferred to method bodies to avoid import errors in the worker environment.
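The deferred-import pattern can be sketched as follows (an illustrative toy class, not the actual source; array stands in for torch so the sketch stays self-contained):

```python
class DeferredImportExtension:
    """Toy illustration of the pattern: heavy dependencies are imported
    inside method bodies, so merely defining or attaching the class never
    triggers those imports in the process that spawns the workers."""

    def update_weight_sketch(self, name: str, shape: tuple[int, ...]):
        # Deferred import: only executed when the method runs inside
        # the worker, where the dependency is guaranteed to be usable.
        import array
        buf = array.array("f", [0.0] * shape[0])  # empty (zeroed) buffer
        return name, len(buf)
```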
Usage
This class is not imported directly by user code. Instead, it is referenced by its fully qualified class name in the vLLM engine configuration:
worker_extension_cls="open_instruct.vllm_utils_workerwrap.WorkerWrap"
vLLM will dynamically load and attach this class to its worker instances.
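For illustration, a configuration sketch assuming a recent vLLM version that accepts worker_extension_cls as an engine argument (the model name is a hypothetical placeholder):

```python
from vllm import LLM

# vLLM dynamically imports the named class and attaches it to each worker,
# so WorkerWrap's methods become callable on the workers.
llm = LLM(
    model="meta-llama/Llama-3.1-8B",  # hypothetical model choice
    worker_extension_cls="open_instruct.vllm_utils_workerwrap.WorkerWrap",
)
```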
Code Reference
Source Location
- Repository: Open Instruct
- File:
open_instruct/vllm_utils_workerwrap.py
Signature
class WorkerWrap:
    def init_process_group(
        self,
        master_address: str,
        master_port: int,
        rank_offset: int,
        world_size: int,
        group_name: str,
        backend: str = "nccl",
        use_ray: bool = False,
        timeout_minutes: int = 120,
    ) -> None:
        """Init torch process group for model weights update."""
        ...

    def update_weight(
        self,
        name: str,
        dtype: str,
        shape: tuple[int, ...],
        empty_cache: bool = False,
    ) -> None:
        """Receive a weight tensor via broadcast and load into the model."""
        ...

    def update_weight_cuda_ipc(
        self,
        name: str,
        dtype: str,
        shape: tuple[int, ...],
        ipc_handles: list | None = None,
        empty_cache: bool = False,
    ) -> None:
        """Receive a weight tensor via CUDA IPC and load into the model."""
        ...
Import
# Not imported directly; referenced as a string in engine configuration:
# worker_extension_cls="open_instruct.vllm_utils_workerwrap.WorkerWrap"
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| master_address | str | IP address of the learner rank 0 process. |
| master_port | int | Port for the distributed process group. |
| rank_offset | int | Offset added to the worker's torch rank to compute its rank in the combined group. |
| world_size | int | Total size of the combined group (1 learner + num_engines * TP workers). |
| name | str | Fully qualified parameter name (e.g., model.layers.0.self_attn.q_proj.weight). |
| dtype | str | String representation of the parameter dtype; must match the model's dtype. |
| shape | tuple[int, ...] | Shape of the parameter tensor to receive. |
| ipc_handles | list | CUDA IPC memory handles for same-node transfer (used by update_weight_cuda_ipc). |
Outputs
| Name | Type | Description |
|---|---|---|
| Side effect | Model weight update | The named parameter in the vLLM model runner is updated in place. |
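The contract can be sketched without torch or NCCL (all names here are illustrative; recv stands in for the NCCL broadcast, and a plain dict stands in for model.load_weights()):

```python
from math import prod

def receive_and_load(model_weights: dict, name: str, shape, recv) -> dict:
    """Sketch of update_weight's side effect: allocate an empty buffer
    of the given shape, fill it in place from the broadcast source, and
    load it into the model under the parameter's name."""
    buf = [0.0] * prod(shape)     # stands in for torch.empty(shape)
    recv(buf)                     # in-place fill, like dist.broadcast
    model_weights[name] = buf     # like model.load_weights([(name, w)])
    return model_weights
```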
Usage Examples
# This class is used indirectly through the LLMRayActor:
# 1. Engine creation passes the worker extension class
engine = create_vllm_engines(
...,
# This tells vLLM to attach WorkerWrap to each worker
)
# 2. The learner initializes the process group
engine.init_process_group.remote(
master_address="10.0.0.1",
master_port=29500,
rank_offset=1, # vLLM workers start at rank 1
world_size=3, # 1 learner + 2 vLLM workers
group_name="openrlhf",
backend="nccl",
)
# 3. After each training step, weights are broadcast
# (handled by broadcast_weights_to_vllm in vllm_utils.py)
for name, param in model.named_parameters():
    handle = engine.update_weight.remote(name, str(param.dtype), param.shape)
    torch.distributed.broadcast(param.data, 0, group=model_update_group)
    ray.get(handle)  # wait for the worker to finish loading this weight
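The ordering in step 3 matters: the non-blocking remote call posts the worker-side receive before the learner issues the matching broadcast, otherwise the collective would deadlock. A toy driver with recorded calls (all names hypothetical, no Ray or NCCL involved) illustrates the per-parameter sequence:

```python
calls = []

def remote_update_weight(name):
    # Stand-in for engine.update_weight.remote: a non-blocking request
    # that makes the worker post its broadcast receive for this param.
    calls.append(("update_weight", name))

def learner_broadcast(name):
    # Stand-in for torch.distributed.broadcast on the learner side;
    # it must follow the matching worker-side receive request.
    calls.append(("broadcast", name))

for name in ["embed.weight", "lm_head.weight"]:
    remote_update_weight(name)  # worker posts recv for this param
    learner_broadcast(name)     # learner sends the param data
```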