

Implementation:Facebookresearch Habitat lab VER InferenceWorker

From Leeroopedia
Knowledge Sources
Domains Embodied_AI, Reinforcement_Learning, Distributed_Training
Last Updated 2026-02-15 00:00 GMT

Overview

InferenceWorker runs policy inference in a dedicated process for the VER training system, batching observations from multiple environments, running the actor-critic forward pass, distributing actions back to environments, and writing experience into the rollout storage.
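The following is a minimal, framework-free sketch of one inference step that makes this data flow concrete. It rests on stated assumptions: plain dicts stand in for the shared transfer buffers, one list stands in for the environment workers' step queues and another for the rollout storage, and the toy `greedy_policy` replaces `actor_critic.act()`. All names here (`inference_step`, `greedy_policy`, `Observation`) are hypothetical and not habitat-lab API. Recurrent hidden states, value predictions and log probabilities are omitted.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Observation = List[float]  # hypothetical: a flat observation vector


def inference_step(
    env_indices: Sequence[int],
    obs_buffer: Dict[int, Observation],
    action_buffer: Dict[int, int],
    policy: Callable[[List[Observation]], List[int]],
    step_commands: List[int],
    storage: List[Tuple[int, Observation, int]],
) -> List[Tuple[int, int]]:
    """Hypothetical single step: gather, one batched forward pass, scatter."""
    # Gather: build one batch from only the environments that requested an action.
    batch = [obs_buffer[i] for i in env_indices]
    # Forward pass: a single batched call instead of one call per environment.
    actions = policy(batch)
    for env_idx, obs, action in zip(env_indices, batch, actions):
        # Scatter: write the action into that environment's slot of the shared buffer.
        action_buffer[env_idx] = action
        # Tell the environment to step (a list stands in for its command queue).
        step_commands.append(env_idx)
        # Record the transition in a flat, append-only experience buffer.
        storage.append((env_idx, obs, action))
    return list(zip(env_indices, actions))


def greedy_policy(batch: List[Observation]) -> List[int]:
    # Toy policy: choose the index of the largest entry in each observation.
    return [max(range(len(obs)), key=obs.__getitem__) for obs in batch]


obs_buffer = {0: [0.1, 0.9], 1: [0.7, 0.3], 2: [0.2, 0.8]}
action_buffer: Dict[int, int] = {}
step_commands: List[int] = []
storage: List[Tuple[int, Observation, int]] = []
# Only environments 1 and 2 are waiting for an action; environment 0 is still stepping.
print(inference_step([1, 2], obs_buffer, action_buffer, greedy_policy, step_commands, storage))
# [(1, 0), (2, 1)]
```

Batching only the environments that are currently waiting, rather than all of them, is what lets fast environments keep moving without being held back by slow ones.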

Description

This module implements the inference-side workers for VER (Variable Experience Rollout) training. It contains two primary classes:

InferenceWorkerProcess is an attrs-decorated process class that performs the core inference loop. Key responsibilities include:

  • Batching observations: Collects incoming environment indices from the inference queue, batches their observations from shared transfer buffers, and applies observation transforms.
  • Policy forward pass: Runs actor_critic.act() on the batched observations with recurrent hidden states and previous actions to produce actions, value predictions, and action log probabilities.
  • Static encoder: Optionally pre-computes visual features using a frozen visual encoder before the full policy forward pass, caching them in the observation dict.
  • Action distribution: Writes computed actions back to shared transfer buffers and enqueues step commands to the corresponding environment workers.
  • Rollout storage updates: Writes the full step data (observations, actions, masks, rewards, value predictions, hidden states, etc.) into the VERRolloutStorage. Supports both variable-experience mode (linear buffer with pointer tracking) and fixed-experience mode (structured step/env indexing).
  • Replay steps: After a rollout completes, the worker replays the final batch of experience to compute bootstrap value estimates for return computation.
  • Policy version tracking: Monitors the policy version and updates local weights when the learner has performed a gradient step (with overlapped learning, each worker maintains its own copy).
  • Adaptive batching: Dynamically adjusts the minimum and maximum number of requests to batch, and the minimum wait time between steps, based on a windowed running mean of step times.
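This page does not spell out the adaptive-batching update rule. The sketch below is a hypothetical controller that only illustrates the mechanism described above: a windowed running mean of step times drives the minimum and maximum request counts and the minimum wait time. The class name `AdaptiveBatcher`, the `target_step_s` threshold and the grow/shrink-by-one rule are all invented for illustration and are not habitat-lab's.

```python
import collections
import statistics


class AdaptiveBatcher:
    """Hypothetical controller; the update rule is illustrative, not habitat-lab's."""

    def __init__(self, num_envs: int, window: int = 8, target_step_s: float = 0.05) -> None:
        self.num_envs = num_envs
        self.target_step_s = target_step_s  # hypothetical "cheap step" threshold
        self.min_reqs = 1  # fewest requests to wait for before stepping
        self.max_reqs = num_envs  # most requests to take into one batch
        self.min_wait_s = 0.0  # minimum time to wait between steps
        # deque(maxlen=window) drops the oldest sample automatically: a sliding window.
        self.step_times: "collections.deque[float]" = collections.deque(maxlen=window)

    def record(self, step_time_s: float) -> None:
        self.step_times.append(step_time_s)
        mean_s = statistics.fmean(self.step_times)  # windowed running mean
        if mean_s < self.target_step_s:
            # Steps are cheap: allow larger batches and wait for more requests.
            self.max_reqs = min(self.max_reqs + 1, self.num_envs)
            self.min_reqs = min(self.min_reqs + 1, self.max_reqs)
            self.min_wait_s = mean_s / 2
        else:
            # Steps are slow: shrink batches so environments are not kept waiting.
            self.max_reqs = max(self.max_reqs - 1, 1)
            self.min_reqs = max(self.min_reqs - 1, 1)
            self.min_wait_s = 0.0


batcher = AdaptiveBatcher(num_envs=4, window=3)
for _ in range(4):
    batcher.record(0.01)
print(batcher.min_reqs, batcher.max_reqs)  # 4 4
batcher.record(0.2)  # window is now [0.01, 0.01, 0.2], mean above the target
print(batcher.min_reqs, batcher.max_reqs)  # 3 3
```

Both branches preserve `1 <= min_reqs <= max_reqs <= num_envs`, so the controller can never ask for more requests than there are environments.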

InferenceWorker is a WorkerBase subclass providing the public API for setup: set_actor_critic_tensors(), set_rollouts(), and start().
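The signature below lists a `setup_queue: SimpleQueue` field on `InferenceWorkerProcess`, which suggests that the setup methods hand objects to an already-running worker through a queue. The sketch illustrates that pattern under clearly stated assumptions: a thread replaces the process, `queue.SimpleQueue` replaces the multiprocessing queue, and `ToyWorker` with its `(name, payload)` message format is hypothetical, not how `WorkerBase` is actually implemented.

```python
import queue
import threading
from typing import Any, Dict, Tuple


class ToyWorker:
    """Hypothetical stand-in for a WorkerBase subclass (a thread instead of a process)."""

    def __init__(self) -> None:
        self._setup_queue: "queue.SimpleQueue[Tuple[str, Any]]" = queue.SimpleQueue()
        self.received: Dict[str, Any] = {}
        # The worker starts immediately but blocks until setup messages arrive.
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def set_actor_critic_tensors(self, tensors: Any) -> None:
        self._setup_queue.put(("actor_critic_tensors", tensors))

    def set_rollouts(self, rollouts: Any) -> None:
        self._setup_queue.put(("rollouts", rollouts))

    def start(self) -> None:
        # Sentinel message: setup is complete, so the main loop may begin.
        self._setup_queue.put(("start", None))

    def join(self, timeout: float = 5.0) -> None:
        self._thread.join(timeout)

    def _run(self) -> None:
        # Worker side: consume setup messages in FIFO order until "start" arrives.
        while True:
            name, payload = self._setup_queue.get()
            if name == "start":
                break
            self.received[name] = payload
        # A real worker would enter its inference loop here.


worker = ToyWorker()
worker.set_actor_critic_tensors({"weights": [0.0, 1.0]})
worker.set_rollouts("shared-rollout-buffer")
worker.start()
worker.join()
print(sorted(worker.received))  # ['actor_critic_tensors', 'rollouts']
```

Because the queue is FIFO, everything sent before `start()` is guaranteed to be in place when the worker's loop begins.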

The module also includes synchronization support through InferenceWorkerSync for coordinating multiple inference workers via locks and barriers, and RolloutEarlyEnds for early termination of rollouts in variable experience mode.
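The page only states that `InferenceWorkerSync` coordinates workers through locks and barriers and that `RolloutEarlyEnds` holds shared multiprocessing values. The sketch below shows how those primitives can combine to stop a rollout early; `ToySync`, `ToyEarlyEnds`, `toy_worker` and their fields (`should_end`, `steps`, the step budget) are hypothetical. Threads and `threading` primitives keep the example self-contained; separate processes would instead create the lock and barrier from the multiprocessing context (`mp_ctx.Lock()`, `mp_ctx.Barrier(n)`).

```python
import multiprocessing
import threading
from typing import List


class ToySync:
    """Hypothetical stand-in for InferenceWorkerSync: one lock and one barrier."""

    def __init__(self, num_workers: int) -> None:
        self.lock = threading.Lock()
        self.all_workers = threading.Barrier(num_workers)


class ToyEarlyEnds:
    """Hypothetical stand-in for RolloutEarlyEnds: values in shared memory."""

    def __init__(self) -> None:
        # lock=False: the outer ToySync.lock already guards these compound updates.
        self.should_end = multiprocessing.Value("b", 0, lock=False)
        self.steps = multiprocessing.Value("i", 0, lock=False)


def toy_worker(idx: int, sync: ToySync, ends: ToyEarlyEnds, step_budget: int, log: List[int]) -> None:
    while True:
        with sync.lock:
            # Check-then-increment must be atomic, hence the shared lock.
            if ends.should_end.value:
                break
            ends.steps.value += 1
            log.append(idx)
            if ends.steps.value >= step_budget:
                ends.should_end.value = 1  # signal every worker to stop
    # No worker moves on to the next rollout until all of them have stopped.
    sync.all_workers.wait()


num_workers, budget = 3, 10
sync, ends, log = ToySync(num_workers), ToyEarlyEnds(), []
threads = [
    threading.Thread(target=toy_worker, args=(i, sync, ends, budget, log))
    for i in range(num_workers)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(log), ends.steps.value, ends.should_end.value)  # 10 10 1
```

A per-`Value` lock would not be enough here: it protects each read or write individually, but not the combined "check flag, increment, maybe set flag" sequence.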

Usage

This module is used internally by the VER trainer. Multiple inference workers can run in parallel, each processing a subset of the environment observations. The number of inference workers is configured via config.habitat_baselines.rl.ver.num_inference_workers.

Code Reference

Source Location

Signature

class InferenceWorker(WorkerBase):
    def __init__(self, mp_ctx: BaseContext, *args, **kwargs): ...
    def set_actor_critic_tensors(self, actor_critic_tensors): ...
    def set_rollouts(self, rollouts): ...
    def start(self): ...

@attr.s(auto_attribs=True)
class InferenceWorkerProcess(ProcessBase):
    setup_queue: SimpleQueue
    inference_worker_idx: int
    num_inference_workers: int
    config: "DictConfig"
    queues: WorkerQueues
    iw_sync: InferenceWorkerSync
    _torch_transfer_buffers: TensorDict
    policy_name: str
    policy_args: Tuple
    device: torch.device
    rollout_ends: RolloutEarlyEnds
    ...
    def step(self) -> Tuple[bool, List[Tuple[int, int]]]: ...
    def finish_rollout(self): ...
    def try_one_step(self) -> bool: ...
    def run(self): ...

Import

from habitat_baselines.rl.ver.inference_worker import (
    InferenceWorker,
    InferenceWorkerProcess,
)

I/O Contract

Inputs

Name | Type | Required | Description
mp_ctx | BaseContext | Yes | Multiprocessing context
inference_worker_idx | int | Yes | Index of this inference worker
num_inference_workers | int | Yes | Total number of inference workers
config | DictConfig | Yes | Full Habitat baselines configuration
queues | WorkerQueues | Yes | Shared queue structure for inter-worker communication
iw_sync | InferenceWorkerSync | Yes | Synchronization primitives (lock, barriers, events) shared across inference workers
_torch_transfer_buffers | TensorDict | Yes | Shared-memory tensor dict for passing observations and actions between environment and inference workers
policy_name | str | Yes | Registered name of the policy to instantiate
policy_args | Tuple | Yes | Arguments to pass to the policy's from_config method
device | torch.device | Yes | Device to run inference on (typically CUDA)
rollout_ends | RolloutEarlyEnds | Yes | Shared multiprocessing values for early rollout termination

Outputs

Name | Type | Description
actions | torch.Tensor | Computed actions written to shared transfer buffers for environment workers to execute
rollout data | VERRolloutStorage | Full step data (observations, actions, values, hidden states, etc.) written to the shared rollout storage
timing reports | Timing | Performance timing data sent to the report worker
preemption signals | dict | Step completion timestamps and indices sent to the preemption decider

Usage Examples

Basic Usage

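import torch
import torch.multiprocessing as mp
from habitat_baselines.rl.ver.inference_worker import InferenceWorker
from habitat_baselines.rl.ver.worker_common import (
    InferenceWorkerSync,
    RolloutEarlyEnds,
    WorkerQueues,
)

# config, worker_queues (a WorkerQueues instance), transfer_buffers, policy_args,
# actor_critic_tensors and rollouts are assumed to be created beforehand by the
# VER trainer; this snippet does not build them.
mp_ctx = mp.get_context("forkserver")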
iw_sync = InferenceWorkerSync(mp_ctx, num_inference_workers=2)
rollout_ends = RolloutEarlyEnds(mp_ctx)

inference_worker = InferenceWorker(
    mp_ctx,
    inference_worker_idx=0,
    num_inference_workers=2,
    config=config,
    queues=worker_queues,
    iw_sync=iw_sync,
    _torch_transfer_buffers=transfer_buffers,
    policy_name="PointNavResNetPolicy",
    policy_args=policy_args,
    device=torch.device("cuda:0"),
    rollout_ends=rollout_ends,
)

# Setup phase
inference_worker.set_actor_critic_tensors(actor_critic_tensors)
inference_worker.set_rollouts(rollouts)
inference_worker.start()
