Implementation:Facebookresearch Habitat lab VER PreemptionDecider

From Leeroopedia
Knowledge Sources
Domains: Embodied_AI, Reinforcement_Learning, Distributed_Training
Last Updated: 2026-02-15 00:00 GMT

Overview

PreemptionDecider is a dedicated worker process in the VER training system. It approximates the rollout preemption schedule (when each rollout should stop collecting experience) that maximizes throughput when training with VER and DD-PPO across heterogeneous GPU workers.

Description

The PreemptionDeciderProcess class runs in its own process and attempts to approximate the optimal preemption time for each rollout. The optimization goal is:

$$S^{*} = \arg\max_{S} \frac{S}{\mathrm{Time}(S) + LT}$$

where S is the number of steps of experience to collect, Time(S) is the time to collect S steps, and LT is the learner time (time to perform a gradient update).
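The objective can be illustrated with a minimal sketch that evaluates S / (Time(S) + LT) over a set of candidate lengths and keeps the best one. The function names and the quadratic toy_time model are hypothetical and only for illustration; the real decider estimates Time(S) from measured step times rather than from a closed-form function.

```python
from typing import Callable, Iterable


def best_rollout_length(
    candidates: Iterable[int],
    time_fn: Callable[[int], float],
    learner_time: float,
) -> int:
    """Return the S in `candidates` that maximizes S / (Time(S) + LT)."""
    # max() with a key evaluates the throughput objective for every candidate S.
    return max(candidates, key=lambda s: s / (time_fn(s) + learner_time))


def toy_time(s: int) -> float:
    """Hypothetical Time(S): superlinear, as if stragglers dominate long rollouts."""
    return 0.01 * s + 0.0001 * s * s


if __name__ == "__main__":
    print(best_rollout_length(range(1, 513), toy_time, learner_time=1.0))
```

The superlinear term is what makes the problem interesting: if Time(S) grew only linearly, a longer rollout would always amortize the fixed learner time better and the largest candidate would win.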

The key challenge is approximating Time(S). The decider does this by:

  1. Tracking per-environment step times: Each inference worker sends a report after completing a step, including a timestamp. The decider maintains a WindowedRunningMean per environment to estimate average step times.
  2. Computing candidate rollout lengths: From the per-environment step times, it constructs a (Workers x Environments x MaxSteps) array of rollout times, bins them, and evaluates each candidate length by computing the expected steps-per-second (SPS).
  3. Selecting the optimal length: The candidate length that maximizes SPS is selected, subject to constraints on maximum steps per worker and total steps across all workers.
  4. Distributed coordination: When training with multiple GPU workers (DD-PPO), the decider uses torch.distributed operations (gather, all_reduce, broadcast) to share step time estimates and synchronize the preemption schedule across all workers.
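Steps 2 and 3 can be sketched in plain Python, assuming each environment steps sequentially at its estimated mean step time and that the per-worker estimates have already been gathered into one nested list (standing in for the torch.distributed exchange in step 4). The function name and the simplified constraints (a per-environment cap plus a total cap) are hypothetical, and where the source describes binning the rollout times, this sketch instead evaluates every distinct finish time as a candidate.

```python
import math
from typing import List, Tuple


def choose_preemption_time(
    step_times: List[List[float]],  # [worker][env] -> mean seconds per step
    max_steps_per_env: int,
    learner_time: float,
    max_total_steps: int,
) -> Tuple[float, int]:
    """Pick the preemption time T maximizing steps(T) / (T + learner_time)."""
    # Flattened (Workers x Environments x MaxSteps) array of finish times:
    # environment (w, e) finishes its k-th step at roughly k * step_time.
    candidate_times = sorted(
        {
            k * t
            for worker in step_times
            for t in worker
            for k in range(1, max_steps_per_env + 1)
        }
    )
    best_time, best_steps, best_sps = 0.0, 0, -1.0
    for T in candidate_times:
        # Steps each environment has finished by time T, capped per environment.
        # The small epsilon guards against (k * t) / t rounding just below k.
        steps = sum(
            min(max_steps_per_env, math.floor(T / t + 1e-9))
            for worker in step_times
            for t in worker
        )
        if steps > max_total_steps:
            break  # steps only grow with T, so later candidates also violate the cap
        sps = steps / (T + learner_time)
        if sps > best_sps:
            best_time, best_steps, best_sps = T, steps, sps
    return best_time, best_steps


if __name__ == "__main__":
    # Two workers with one environment each; worker 1 steps twice as slowly.
    print(choose_preemption_time([[0.01], [0.02]], 100, 1.0, 10_000))
```

In this example, once the fast environment reaches its 100-step cap, extra time only adds steps from the slow environment, so SPS starts to fall and the sketch stops both workers at that point.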

The decider also maintains:

  • Learner time: A windowed running mean of learning times, used in the throughput calculation.
  • Preemption error: The difference between predicted and actual rollout times, added as a correction factor.
  • Rollout end signals: Writes the computed preemption time and step count to shared multiprocessing Value objects (rollout_ends) that inference workers read to decide when to stop collecting experience.
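The learner-time and preemption-error bookkeeping can be sketched with a bounded deque as the windowed mean. WindowedMean and RolloutTimeModel are hypothetical stand-ins; they do not reproduce the interface of the WindowedRunningMean the decider actually uses, and adding the mean error directly to the prediction is an assumption about how the correction factor is applied.

```python
from collections import deque


class WindowedMean:
    """Mean of the most recent `window` values (hypothetical stand-in)."""

    def __init__(self, window: int) -> None:
        # A bounded deque evicts the oldest value automatically once full.
        self._values: deque = deque(maxlen=window)

    def add(self, value: float) -> None:
        self._values.append(value)

    @property
    def mean(self) -> float:
        return sum(self._values) / len(self._values) if self._values else 0.0


class RolloutTimeModel:
    """Hypothetical bookkeeping for learner time and preemption error."""

    def __init__(self, window: int = 5) -> None:
        self.learner_time = WindowedMean(window)
        self.preemption_error = WindowedMean(window)

    def record_learner_time(self, seconds: float) -> None:
        self.learner_time.add(seconds)

    def record_rollout(self, predicted: float, actual: float) -> None:
        # Positive error means the rollout ran longer than predicted.
        self.preemption_error.add(actual - predicted)

    def corrected(self, predicted: float) -> float:
        # Add the recent mean error back in as a correction factor.
        return predicted + self.preemption_error.mean


model = RolloutTimeModel(window=3)
for seconds in [1.0, 1.2, 0.8, 1.0]:
    model.record_learner_time(seconds)  # the first 1.0 falls out of the window
model.record_rollout(predicted=2.0, actual=2.1)
model.record_rollout(predicted=2.0, actual=2.3)
print(model.learner_time.mean, model.corrected(2.0))
```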

PreemptionDeciderWorker is a WorkerBase subclass that exposes the public API used by the trainer: start_rollout(), end_rollout(), and learner_time().
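Comparing the two signatures, the worker-side start_rollout() and end_rollout(num_next_steps) take no timestamps, while the process-side methods take start_time and end_steps_time. That suggests the worker stamps each call and forwards it to the decider process. The following is a minimal single-process sketch of that forwarding pattern under this assumption: DeciderFrontEnd, DeciderBackEnd, and the ("shutdown", ()) message are hypothetical, and a plain queue.Queue stands in for the real inter-process queues.

```python
import queue
import time
from typing import List, Tuple

Message = Tuple[str, tuple]


class DeciderFrontEnd:
    """Hypothetical front end: stamps each call and forwards it as a message."""

    def __init__(self, commands: "queue.Queue[Message]") -> None:
        self._commands = commands

    def start_rollout(self) -> None:
        self._commands.put(("start_rollout", (time.perf_counter(),)))

    def end_rollout(self, num_next_steps: int) -> None:
        self._commands.put(("end_rollout", (time.perf_counter(), num_next_steps)))

    def learner_time(self, learning_time: float) -> None:
        self._commands.put(("learner_time", (learning_time,)))


class DeciderBackEnd:
    """Hypothetical back-end loop: dispatches each message to a same-named handler."""

    def __init__(self) -> None:
        self.log: List[Message] = []

    def start_rollout(self, start_time: float) -> None:
        self.log.append(("start_rollout", (start_time,)))

    def end_rollout(self, end_steps_time: float, num_next_steps: int) -> None:
        self.log.append(("end_rollout", (end_steps_time, num_next_steps)))

    def learner_time(self, learner_time: float) -> None:
        self.log.append(("learner_time", (learner_time,)))

    def run(self, commands: "queue.Queue[Message]") -> None:
        # Block on the queue until a shutdown message arrives.
        while True:
            name, args = commands.get()
            if name == "shutdown":
                return
            getattr(self, name)(*args)
```

Keeping the front end to timestamping and enqueuing means the trainer's calls stay cheap, while all estimation work happens in the back end's loop.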

Usage

This module is used internally by the VER trainer to optimize experience collection throughput. It is particularly important in distributed settings where different GPU workers may have different environment speeds, and collecting a fixed number of steps would leave fast workers idle while slow workers catch up.
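The idle-time problem can be made concrete with integer milliseconds. The step times below are illustrative, not measurements: with a fixed budget of 128 steps, the fast worker finishes after 1280 ms and then waits 1280 ms for the slow one, whereas a shared 1280 ms deadline lets both stop together with 128 + 64 = 192 steps. Ignoring learner time, that is 150 steps/s instead of 256 steps in 2560 ms (100 steps/s). The helper names are hypothetical.

```python
from typing import List


def idle_ms_fixed_steps(step_ms: List[int], num_steps: int) -> List[int]:
    """Idle time per worker when every worker must collect `num_steps` steps."""
    finish = [ms * num_steps for ms in step_ms]
    slowest = max(finish)
    # Each worker waits until the slowest worker has finished its rollout.
    return [slowest - f for f in finish]


def steps_by_deadline(step_ms: List[int], deadline_ms: int) -> List[int]:
    """Steps each worker completes if all workers stop at a shared deadline."""
    return [deadline_ms // ms for ms in step_ms]


step_ms = [10, 20]  # illustrative: worker 1's environments step twice as slowly
print(idle_ms_fixed_steps(step_ms, 128))  # [1280, 0]
print(steps_by_deadline(step_ms, 1280))   # [128, 64]
```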

Code Reference

Source Location

Signature

class PreemptionDeciderWorker(WorkerBase):
    def __init__(
        self,
        mp_ctx: BaseContext,
        hostname: str,
        port: int,
        world_rank: int,
        world_size: int,
        config: "DictConfig",
        queues: WorkerQueues,
        my_t_zero: float,
    ): ...
    def start_rollout(self): ...
    def end_rollout(self, num_next_steps): ...
    def learner_time(self, learning_time): ...

@attr.s(auto_attribs=True)
class PreemptionDeciderProcess(ProcessBase):
    hostname: str
    port: int
    world_rank: int
    world_size: int
    config: "DictConfig"
    queues: WorkerQueues
    my_t_zero: float
    rollout_ends: RolloutEarlyEnds
    ...
    def update(self, num_next_steps: int): ...
    def policy_step(self, data): ...
    def start_rollout(self, start_time: float): ...
    def end_rollout(self, end_steps_time: float, num_next_steps: int): ...
    def learner_time(self, learner_time: float): ...
    def run(self): ...

Import

from habitat_baselines.rl.ver.preemption_decider import (
    PreemptionDeciderWorker,
    PreemptionDeciderProcess,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| mp_ctx | BaseContext | Yes | Multiprocessing context |
| hostname | str | Yes | Hostname for distributed communication (gloo backend) |
| port | int | Yes | Port for distributed communication |
| world_rank | int | Yes | Rank of this worker in the distributed group |
| world_size | int | Yes | Total number of distributed workers |
| config | DictConfig | Yes | Full Habitat baselines configuration |
| queues | WorkerQueues | Yes | Shared queue structure, including the preemption_decider queue |
| my_t_zero | float | Yes | Reference start time (perf_counter) for this process |

Outputs

| Name | Type | Description |
| --- | --- | --- |
| rollout_ends.time | multiprocessing.Value | The computed optimal rollout end time; -1.0 when not yet determined |
| rollout_ends.steps | multiprocessing.Value | The computed optimal number of steps to collect; -1.0 when not yet determined |
| report data | dict | Statistics sent to the report worker: real_steps_collected, expected_steps_collected, real_rollout_time, expected_rollout_time |
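The shared rollout_ends signal can be sketched with two multiprocessing.Value doubles that start at the -1.0 sentinel. RolloutEnds and should_stop are hypothetical stand-ins for the real RolloutEarlyEnds class and the inference workers' check; in particular, stopping at whichever published limit is reached first is an assumption, not confirmed behavior.

```python
import multiprocessing as mp


class RolloutEnds:
    """Hypothetical shared end-of-rollout signal; -1.0 means 'not decided yet'."""

    def __init__(self) -> None:
        self.time = mp.Value("d", -1.0)
        self.steps = mp.Value("d", -1.0)

    def publish(self, end_time: float, steps: int) -> None:
        # Writer side (decider): one lock guards both values so a reader
        # never observes a half-updated pair.
        with self.time.get_lock():
            self.time.value = end_time
            self.steps.value = float(steps)

    def reset(self) -> None:
        with self.time.get_lock():
            self.time.value = -1.0
            self.steps.value = -1.0


def should_stop(ends: RolloutEnds, now: float, steps_collected: int) -> bool:
    """Reader side (inference worker): assumed to stop at whichever limit hits first."""
    with ends.time.get_lock():
        end_time, end_steps = ends.time.value, ends.steps.value
    if end_time < 0 and end_steps < 0:
        return False  # no preemption decision yet, keep collecting
    hit_time = end_time >= 0 and now >= end_time
    hit_steps = end_steps >= 0 and steps_collected >= end_steps
    return hit_time or hit_steps
```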

Usage Examples

Basic Usage

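import time

import torch.multiprocessing as mp
from habitat_baselines.rl.ver.preemption_decider import PreemptionDeciderWorker
from habitat_baselines.rl.ver.worker_common import WorkerQueues

# Assumed to be defined by the surrounding trainer code:
#   config                - the full Habitat baselines DictConfig
#   num_envs, num_steps   - environments per worker and rollout length
#   learning_time_seconds - wall-clock duration of the last learner update
mp_ctx = mp.get_context("forkserver")
queues = WorkerQueues(num_environments=4, mp_ctx=mp_ctx)
my_t_zero = time.perf_counter()  # reference time for this process's timestamps

preemption_decider = PreemptionDeciderWorker(
    mp_ctx,
    hostname="localhost",
    port=8738,
    world_rank=0,
    world_size=1,
    config=config,
    queues=queues,
    my_t_zero=my_t_zero,
)

# During the training loop, mark the start of each rollout:
preemption_decider.start_rollout()

# ... experience collection happens ...

# At the end of the rollout, notify the decider, then report the learner update time:
preemption_decider.end_rollout(num_next_steps=num_envs * (num_steps + 1))
preemption_decider.learner_time(learning_time_seconds)

# Inference workers read the rollout end signals
# (-1.0 means no preemption decision has been made yet):
# preemption_decider.rollout_ends.time.value
# preemption_decider.rollout_ends.steps.value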
