Implementation:Facebookresearch Habitat lab VER InferenceWorker
| Knowledge Sources | |
|---|---|
| Domains | Embodied_AI, Reinforcement_Learning, Distributed_Training |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
InferenceWorker runs policy inference in a dedicated process for the VER training system, batching observations from multiple environments, running the actor-critic forward pass, distributing actions back to environments, and writing experience into the rollout storage.
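As a rough illustration of the cycle described above (batch, act, distribute), here is a single-process sketch in plain Python. It is not the habitat-lab implementation. queue.Queue and ordinary dicts stand in for the multiprocessing queues and shared-memory transfer buffers. The names inference_step, obs_buffer, action_buffer, env_step_queues, and the "step" message are hypothetical.

```python
import queue
from typing import Callable, Dict, List


def inference_step(
    inference_queue: "queue.Queue[int]",
    obs_buffer: Dict[int, List[float]],
    action_buffer: Dict[int, int],
    env_step_queues: Dict[int, "queue.Queue[str]"],
    policy: Callable[[List[List[float]]], List[int]],
    max_batch: int = 4,
) -> List[int]:
    """Hypothetical sketch: run one batched inference step and return the env indices served."""
    # Block for the first request, then drain without blocking up to max_batch.
    env_idxs: List[int] = [inference_queue.get()]
    while len(env_idxs) < max_batch:
        try:
            env_idxs.append(inference_queue.get_nowait())
        except queue.Empty:
            break

    # Batch observations only for the environments that asked for an action.
    batch = [obs_buffer[i] for i in env_idxs]
    actions = policy(batch)

    # Scatter actions back and tell each environment it may step.
    for env_idx, action in zip(env_idxs, actions):
        action_buffer[env_idx] = action
        env_step_queues[env_idx].put("step")
    return env_idxs


if __name__ == "__main__":
    requests: "queue.Queue[int]" = queue.Queue()
    for env_idx in (2, 0, 3):
        requests.put(env_idx)
    obs = {i: [float(i)] for i in range(4)}
    actions: Dict[int, int] = {}
    step_queues: Dict[int, "queue.Queue[str]"] = {i: queue.Queue() for i in range(4)}
    served = inference_step(
        requests, obs, actions, step_queues,
        policy=lambda b: [int(o[0]) * 10 for o in b],
    )
    print(served, actions)  # [2, 0, 3] {2: 20, 0: 0, 3: 30}
```

Only environments that have posted a request are batched, so a slow environment does not hold up the forward pass for the others.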
Description
This module implements the inference-side workers for VER (Variable Experience Rollout) training. It contains two primary classes:
InferenceWorkerProcess is an attrs-decorated process class that performs the core inference loop. Key responsibilities include:
- Batching observations: Collects incoming environment indices from the inference queue, batches their observations from shared transfer buffers, and applies observation transforms.
- Policy forward pass: Runs actor_critic.act() on the batched observations, together with the recurrent hidden states and previous actions, to produce actions, value predictions, and action log probabilities.
- Static encoder: Optionally precomputes visual features with a frozen visual encoder before the full policy forward pass and caches them in the observation dict.
- Action distribution: Writes computed actions back to shared transfer buffers and enqueues step commands to the corresponding environment workers.
- Rollout storage updates: Writes the full step data (observations, actions, masks, rewards, value predictions, hidden states, etc.) into the VERRolloutStorage. Supports both variable-experience mode (linear buffer with pointer tracking) and fixed-experience mode (structured step/env indexing).
- Replay steps: After a rollout completes, the worker replays the final batch of experience to compute bootstrap value estimates for return computation.
- Policy version tracking: Monitors the policy version and updates local weights when the learner has performed a gradient step (with overlapped learning, each worker maintains its own copy).
- Adaptive batching: Dynamically adjusts the minimum and maximum number of requests to batch, and the minimum wait time between steps, based on a windowed running mean of step times.
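The adaptive-batching idea can be sketched with a windowed running mean. In this minimal, hypothetical sketch, WindowedMean keeps only the most recent step times, and adjust_batching applies an assumed rule: raise the minimum batch size by one when the mean step time exceeds a target, lower it by one otherwise, and set the minimum wait time to 25% of the mean step time. The rule, the 25% fraction, and the function names are assumptions, not habitat-lab's actual heuristics; the worker described above also adjusts the maximum batch size, which this sketch leaves fixed.

```python
from collections import deque
from typing import Deque, Tuple


class WindowedMean:
    """Mean of the most recent `window` samples; older samples fall out."""

    def __init__(self, window: int) -> None:
        self._values: Deque[float] = deque(maxlen=window)

    def add(self, value: float) -> None:
        self._values.append(value)

    @property
    def mean(self) -> float:
        return sum(self._values) / len(self._values) if self._values else 0.0


def adjust_batching(
    step_times: WindowedMean,
    min_reqs: int,
    max_reqs: int,
    target_step_time: float,
) -> Tuple[int, float]:
    """Hypothetical rule; returns (new_min_reqs, new_min_wait_time)."""
    mean = step_times.mean
    if mean > target_step_time:
        # Slow steps: require more requests per forward pass to amortize its cost.
        new_min = min(min_reqs + 1, max_reqs)
    else:
        # Fast steps: allow smaller batches so environments are served sooner.
        new_min = max(min_reqs - 1, 1)
    # Assumed: minimum wait time is a fixed fraction (25%) of the mean step time.
    return new_min, 0.25 * mean
```

Using a window rather than a cumulative mean lets the batching parameters track changes in step time (for example, when the mix of scenes being rendered shifts) instead of being dominated by old measurements.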
InferenceWorker is a WorkerBase subclass providing the public API for setup: set_actor_critic_tensors(), set_rollouts(), and start().
Inference workers coordinate with one another through InferenceWorkerSync, which holds the lock and barriers shared by all inference workers, and through RolloutEarlyEnds, which allows rollouts to end early in variable experience mode.
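The lock-and-barrier pattern can be illustrated with threads, since threading.Lock and threading.Barrier expose the same acquire/wait interface as their multiprocessing counterparts. In this hypothetical sketch (run_rollout and batch_size are illustrative, not part of habitat-lab), workers take turns claiming pending requests under the lock, and a barrier holds every worker until all of them have drained the shared pool.

```python
import threading
from typing import List


def run_rollout(num_workers: int, pending: List[int], batch_size: int = 2) -> List[List[int]]:
    """Hypothetical sketch: workers share a request pool guarded by a lock."""
    lock = threading.Lock()
    barrier = threading.Barrier(num_workers)
    served: List[List[int]] = [[] for _ in range(num_workers)]

    def worker(idx: int) -> None:
        while True:
            # Only one worker at a time may claim requests from the shared pool.
            with lock:
                if not pending:
                    break
                batch = pending[:batch_size]
                del pending[:batch_size]
            served[idx].extend(batch)
        # No worker proceeds to the next rollout until every worker gets here.
        barrier.wait()

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return served
```

Which worker serves which request depends on scheduling, but every request is claimed exactly once because claims happen only while the lock is held.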
Usage
This module is used internally by the VER trainer. Multiple inference workers can run in parallel, each processing a subset of the environment observations. The number of inference workers is configured via config.habitat_baselines.rl.ver.num_inference_workers.
Code Reference
Source Location
- Repository: Facebookresearch_Habitat_lab
- File: habitat-baselines/habitat_baselines/rl/ver/inference_worker.py
- Lines: 1-576
Signature
```python
class InferenceWorker(WorkerBase):
    def __init__(self, mp_ctx: BaseContext, *args, **kwargs): ...
    def set_actor_critic_tensors(self, actor_critic_tensors): ...
    def set_rollouts(self, rollouts): ...
    def start(self): ...


@attr.s(auto_attribs=True)
class InferenceWorkerProcess(ProcessBase):
    setup_queue: SimpleQueue
    inference_worker_idx: int
    num_inference_workers: int
    config: "DictConfig"
    queues: WorkerQueues
    iw_sync: InferenceWorkerSync
    _torch_transfer_buffers: TensorDict
    policy_name: str
    policy_args: Tuple
    device: torch.device
    rollout_ends: RolloutEarlyEnds
    ...

    def step(self) -> Tuple[bool, List[Tuple[int, int]]]: ...
    def finish_rollout(self): ...
    def try_one_step(self) -> bool: ...
    def run(self): ...
```
Import
```python
from habitat_baselines.rl.ver.inference_worker import (
    InferenceWorker,
    InferenceWorkerProcess,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| mp_ctx | BaseContext | Yes | Multiprocessing context |
| inference_worker_idx | int | Yes | Index of this inference worker |
| num_inference_workers | int | Yes | Total number of inference workers |
| config | DictConfig | Yes | Full Habitat baselines configuration |
| queues | WorkerQueues | Yes | Shared queue structure for inter-worker communication |
| iw_sync | InferenceWorkerSync | Yes | Synchronization primitives (lock, barriers, events) shared across inference workers |
| _torch_transfer_buffers | TensorDict | Yes | Shared memory tensor dict for passing observations and actions between environment and inference workers |
| policy_name | str | Yes | Registered name of the policy to instantiate |
| policy_args | Tuple | Yes | Arguments to pass to the policy's from_config method |
| device | torch.device | Yes | Device to run inference on (typically CUDA) |
| rollout_ends | RolloutEarlyEnds | Yes | Shared multiprocessing values for early rollout termination |
Outputs
| Name | Type | Description |
|---|---|---|
| actions | torch.Tensor | Computed actions written to shared transfer buffers for environment workers to execute |
| rollout data | VERRolloutStorage | Full step data (observations, actions, values, hidden states, etc.) written to the shared rollout storage |
| timing reports | Timing | Performance timing data sent to the report worker |
| preemption signals | dict | Step completion timestamps and indices sent to the preemption decider |
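The two storage layouts behind the rollout data output differ mainly in how a write position is chosen. The classes below are hypothetical, pure-Python sketches that store only an environment index and a reward rather than full step data. LinearBuffer appends each step at a single write pointer, so faster environments can contribute more steps to a rollout (variable experience). GridBuffer gives every environment its own column and step counter, so all environments contribute the same number of steps (fixed experience).

```python
from typing import List, Optional, Tuple


class LinearBuffer:
    """Variable-experience sketch: steps go to the next free slot, whichever env produced them."""

    def __init__(self, capacity: int) -> None:
        self.data: List[Optional[Tuple[int, float]]] = [None] * capacity
        self.ptr = 0

    def write(self, env_idx: int, reward: float) -> int:
        if self.ptr >= len(self.data):
            raise IndexError("rollout buffer is full")
        slot = self.ptr
        self.data[slot] = (env_idx, reward)
        self.ptr += 1
        return slot

    def full(self) -> bool:
        return self.ptr == len(self.data)


class GridBuffer:
    """Fixed-experience sketch: each env owns a column and advances its own step counter."""

    def __init__(self, num_steps: int, num_envs: int) -> None:
        self.data: List[List[Optional[float]]] = [[None] * num_envs for _ in range(num_steps)]
        self.step = [0] * num_envs

    def write(self, env_idx: int, reward: float) -> Tuple[int, int]:
        t = self.step[env_idx]
        self.data[t][env_idx] = reward
        self.step[env_idx] += 1
        return t, env_idx
```

With the linear layout, the rollout is complete when the pointer reaches capacity, regardless of how the steps are distributed across environments; with the grid layout, each environment must fill its own column.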
Usage Examples
Basic Usage
```python
import torch
import torch.multiprocessing as mp

from habitat_baselines.rl.ver.inference_worker import InferenceWorker
from habitat_baselines.rl.ver.worker_common import (
    InferenceWorkerSync,
    RolloutEarlyEnds,
    WorkerQueues,
)

# config, worker_queues (a WorkerQueues instance), transfer_buffers,
# policy_args, actor_critic_tensors, and rollouts are assumed to have been
# created earlier by the VER trainer's setup code.
mp_ctx = mp.get_context("forkserver")
iw_sync = InferenceWorkerSync(mp_ctx, num_inference_workers=2)
rollout_ends = RolloutEarlyEnds(mp_ctx)

inference_worker = InferenceWorker(
    mp_ctx,
    inference_worker_idx=0,
    num_inference_workers=2,
    config=config,
    queues=worker_queues,
    iw_sync=iw_sync,
    _torch_transfer_buffers=transfer_buffers,
    policy_name="PointNavResNetPolicy",
    policy_args=policy_args,
    device=torch.device("cuda:0"),
    rollout_ends=rollout_ends,
)

# Setup phase: hand over the policy weights and rollout storage,
# then start the worker.
inference_worker.set_actor_critic_tensors(actor_critic_tensors)
inference_worker.set_rollouts(rollouts)
inference_worker.start()
```