

Implementation:Facebookresearch Habitat lab Updater

From Leeroopedia
Revision as of 12:36, 16 February 2026 by Admin (Auto-imported from implementations/Facebookresearch_Habitat_lab_Updater.md)
Knowledge Sources
Domains: Embodied_AI, Reinforcement_Learning
Last Updated: 2026-02-15 00:00 GMT

Overview

Updater is an abstract base class that defines the interface for policy update algorithms, providing methods for performing updates from rollout storage, managing optimizer state, and post-update hooks.

Description

Updater establishes the contract that all policy updaters must follow. The central abstract method, update, takes a Storage object containing collected rollout data and returns a dictionary mapping loss/metric names to float values. The class also provides default implementations for the remaining hooks: after_update (called after each update; a no-op), get_resume_state (for serializing optimizer state; returns None), load_state_dict (for restoring optimizer state; a no-op), and an lr_scheduler property that returns None. Concrete updaters, such as PPO, override these methods with actual optimization logic.
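To make the contract concrete, here is a minimal, self-contained sketch that mirrors the documented interface in plain Python. `UpdaterContract`, `NoOpUpdater` and `can_instantiate_base` are hypothetical names used only for illustration; the real class lives in `habitat_baselines.rl.ppo.updater`, and `rollouts` is typed loosely here rather than as a habitat `Storage`.

```python
import abc
from typing import Any, Dict, Optional


class UpdaterContract(abc.ABC):
    """Hypothetical local mirror of the documented Updater interface."""

    @abc.abstractmethod
    def update(self, rollouts: Any) -> Dict[str, float]:
        """Subclasses must turn collected rollouts into loss/metric values."""

    @property
    def lr_scheduler(self):
        return None  # no learning-rate scheduler by default

    def after_update(self) -> None:
        """No-op hook by default."""

    def get_resume_state(self) -> Optional[Dict[str, Any]]:
        return None  # nothing to checkpoint by default

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        pass  # nothing to restore by default


class NoOpUpdater(UpdaterContract):
    """Implements only the abstract method; every other method is inherited."""

    def update(self, rollouts: Any) -> Dict[str, float]:
        return {"loss": 0.0}


def can_instantiate_base() -> bool:
    """Return True if the abstract base can be instantiated directly."""
    try:
        UpdaterContract()  # abstract: update() is not implemented
    except TypeError:
        return False
    return True
```

Because `update` is declared with `abc.abstractmethod`, Python refuses to instantiate the base class, while a subclass that implements only `update` still works and inherits the no-op defaults.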

Usage

Subclass Updater to implement a specific policy optimization algorithm (e.g., PPO, SAC). Register the concrete implementation with the baseline registry via baseline_registry.register_updater.
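Decorator-based registration maps a name to a class so that a trainer can look the class up from a configuration string. The sketch below is a toy stand-in written only to illustrate the pattern: `ToyRegistry`, its `register_updater` and `get_updater` methods, and the key-by-class-name default are hypothetical choices. habitat-lab's `baseline_registry` has its own implementation, and its exact keying and lookup API should be checked in `habitat_baselines.common.baseline_registry`.

```python
from typing import Dict, Optional


class ToyRegistry:
    """Hypothetical name -> class registry illustrating decorator registration."""

    def __init__(self) -> None:
        self._updaters: Dict[str, type] = {}

    def register_updater(self, cls: Optional[type] = None, *, name: Optional[str] = None):
        def wrap(to_register: type) -> type:
            key = name or to_register.__name__  # default key: the class name
            if key in self._updaters:
                raise KeyError(f"updater {key!r} already registered")
            self._updaters[key] = to_register
            return to_register  # return the class unchanged so it stays usable

        # Support both bare @register_updater and @register_updater(name=...).
        return wrap(cls) if cls is not None else wrap

    def get_updater(self, name: str) -> type:
        return self._updaters[name]


registry = ToyRegistry()


@registry.register_updater
class MyUpdater:
    pass


@registry.register_updater(name="custom_ppo")
class OtherUpdater:
    pass
```

Returning the class unchanged from the decorator means registration has no effect on how the class is used directly; it only adds a lookup entry.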

Code Reference

Source Location

Signature

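import abc
from typing import Any, Dict

from habitat_baselines.common.storage import Storage


class Updater(abc.ABC):
    @abc.abstractmethod
    def update(self, rollouts: Storage) -> Dict[str, float]:
        """Perform a policy update from data in the storage object."""

    @property
    def lr_scheduler(self):
        return None  # no learning-rate scheduler by default

    def after_update(self) -> None:
        """Hook called after the policy update (no-op by default)."""

    def get_resume_state(self) -> Dict[str, Any]:
        return None  # default: no optimizer state to checkpoint

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        pass  # default: nothing to restore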

Import

from habitat_baselines.rl.ppo.updater import Updater

I/O Contract

Inputs (update)

Name | Type | Required | Description
rollouts | Storage | Yes | Rollout storage containing the collected experience data for the update

Outputs (update)

Name | Type | Description
losses | Dict[str, float] | Dictionary mapping loss/metric names to their float values for the current update
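On the caller's side, a training loop typically calls update once per batch of collected rollouts, records the returned dictionary, and then calls after_update. The sketch below illustrates that calling sequence under simplifying assumptions: a plain list of per-step loss values stands in for Storage, and `UpdaterBase`, `MeanLossUpdater` and `train` are hypothetical names that do not come from habitat-lab.

```python
import abc
from typing import Any, Dict, List


class UpdaterBase(abc.ABC):
    """Reduced local mirror of the interface: only what this sketch uses."""

    @abc.abstractmethod
    def update(self, rollouts: Any) -> Dict[str, float]: ...

    def after_update(self) -> None:
        pass


class MeanLossUpdater(UpdaterBase):
    """Hypothetical updater: 'rollouts' is a list of per-step loss values."""

    def __init__(self) -> None:
        self.num_updates = 0

    def update(self, rollouts: List[float]) -> Dict[str, float]:
        mean = sum(rollouts) / len(rollouts)
        return {"loss": mean, "num_steps": float(len(rollouts))}

    def after_update(self) -> None:
        self.num_updates += 1  # simple bookkeeping once each update finishes


def train(updater: UpdaterBase, batches: List[List[float]]) -> List[Dict[str, float]]:
    history = []
    for rollouts in batches:
        losses = updater.update(rollouts)  # Dict[str, float], per the I/O contract
        history.append(losses)             # a real trainer would log these values
        updater.after_update()
    return history
```

Because every updater returns a flat name-to-float dictionary, the loop can log any algorithm's metrics without knowing which losses it computes.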

Key Methods

update (abstract)

def update(self, rollouts: Storage) -> Dict[str, float]

Performs a policy update from data in the storage object. Must be implemented by subclasses.

after_update

def after_update(self) -> None

Hook called after the policy update completes. Default implementation is a no-op.
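One plausible use of this hook is advancing a learning-rate schedule once per update and exposing it through the lr_scheduler property. This page does not say whether a given habitat-lab updater does this, so the sketch below is only an illustration in plain Python: `LinearDecay` and `ScheduledUpdater` are hypothetical, and a PyTorch updater would more likely hold a `torch.optim.lr_scheduler` object instead.

```python
import abc
from typing import Any, Dict


class UpdaterBase(abc.ABC):
    """Reduced local mirror of the interface used by this sketch."""

    @abc.abstractmethod
    def update(self, rollouts: Any) -> Dict[str, float]: ...

    @property
    def lr_scheduler(self):
        return None

    def after_update(self) -> None:
        pass


class LinearDecay:
    """Hypothetical schedule: lr falls linearly from base_lr to 0 over total_updates."""

    def __init__(self, base_lr: float, total_updates: int) -> None:
        self.base_lr = base_lr
        self.total_updates = total_updates
        self.step_count = 0

    def step(self) -> None:
        # Clamp so the learning rate never goes below zero.
        self.step_count = min(self.step_count + 1, self.total_updates)

    @property
    def lr(self) -> float:
        return self.base_lr * (1 - self.step_count / self.total_updates)


class ScheduledUpdater(UpdaterBase):
    def __init__(self, base_lr: float, total_updates: int) -> None:
        self._scheduler = LinearDecay(base_lr, total_updates)

    @property
    def lr_scheduler(self):
        return self._scheduler  # override the default None

    def update(self, rollouts: Any) -> Dict[str, float]:
        return {"lr": self._scheduler.lr}  # report the lr used for this update

    def after_update(self) -> None:
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()  # advance the schedule once per update
```

Stepping the schedule in after_update keeps update focused on the optimization itself and guarantees exactly one schedule step per update.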

get_resume_state

def get_resume_state(self) -> Dict[str, Any]

Returns the optimizer state for checkpointing. The default implementation returns None (despite the Dict[str, Any] annotation), meaning there is no state to save.

load_state_dict

def load_state_dict(self, state: Dict[str, Any]) -> None

Restores the optimizer state from a checkpoint. Default is a no-op.
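get_resume_state and load_state_dict are meant to be used as a pair: the trainer stores the first method's return value in a checkpoint and passes it back to the second when resuming. Because the default get_resume_state returns None, a caller should guard against that before restoring. The sketch below assumes a plain dictionary as a stand-in for optimizer state and uses pickle for the checkpoint; `CountingUpdater`, `StatelessUpdater`, `save_checkpoint` and `resume` are hypothetical names. A PyTorch updater would instead return `optimizer.state_dict()`, as the Basic Usage example does.

```python
import abc
import pickle
from typing import Any, Dict, Optional


class UpdaterBase(abc.ABC):
    """Reduced local mirror of the interface used by this sketch."""

    @abc.abstractmethod
    def update(self, rollouts: Any) -> Dict[str, float]: ...

    def get_resume_state(self) -> Optional[Dict[str, Any]]:
        return None  # default: nothing to save

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        pass  # default: nothing to restore


class CountingUpdater(UpdaterBase):
    """Hypothetical updater whose only 'optimizer state' is a step counter."""

    def __init__(self) -> None:
        self.optimizer_state = {"step": 0}

    def update(self, rollouts: Any) -> Dict[str, float]:
        self.optimizer_state["step"] += 1
        return {"step": float(self.optimizer_state["step"])}

    def get_resume_state(self) -> Dict[str, Any]:
        return {"optimizer": dict(self.optimizer_state)}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.optimizer_state = dict(state["optimizer"])


class StatelessUpdater(UpdaterBase):
    """Hypothetical updater that keeps the default (None) resume state."""

    def update(self, rollouts: Any) -> Dict[str, float]:
        return {}


def save_checkpoint(updater: UpdaterBase) -> bytes:
    return pickle.dumps({"updater": updater.get_resume_state()})


def resume(updater: UpdaterBase, blob: bytes) -> None:
    state = pickle.loads(blob)["updater"]
    if state is not None:  # the base-class default returns None
        updater.load_state_dict(state)
```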

Usage Examples

Basic Usage

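from typing import Any, Dict

import torch

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.storage import Storage
from habitat_baselines.rl.ppo.updater import Updater


@baseline_registry.register_updater
class MyPPOUpdater(Updater):
    def __init__(self, actor_critic, clip_param, ppo_epoch, num_mini_batch, lr):
        self.actor_critic = actor_critic
        self.clip_param = clip_param  # used by the clipped surrogate objective
        self.ppo_epoch = ppo_epoch
        self.num_mini_batch = num_mini_batch
        self.optimizer = torch.optim.Adam(actor_critic.parameters(), lr=lr)

    def _compute_loss(self, batch) -> torch.Tensor:
        # Hypothetical helper: compute the clipped PPO surrogate, value and
        # entropy terms for one mini-batch and return a scalar loss tensor.
        raise NotImplementedError

    def update(self, rollouts: Storage) -> Dict[str, float]:
        total_loss = 0.0
        num_updates = 0
        for _epoch in range(self.ppo_epoch):
            # Illustrative mini-batch iteration: the generator's name and
            # arguments depend on the concrete Storage (recent habitat-lab
            # RolloutStorage versions provide data_generator(advantages, num_mini_batch)).
            for batch in rollouts.recurrent_generator(self.num_mini_batch):
                loss = self._compute_loss(batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                num_updates += 1
        # Report the mean loss over all mini-batch updates, not the raw sum.
        return {"ppo_loss": total_loss / max(num_updates, 1)}

    def get_resume_state(self) -> Dict[str, Any]:
        return {"optimizer": self.optimizer.state_dict()}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.optimizer.load_state_dict(state["optimizer"])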
