Implementation:Facebookresearch Habitat lab Updater
| Knowledge Sources | |
|---|---|
| Domains | Embodied_AI, Reinforcement_Learning |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Updater is an abstract base class that defines the interface for policy update algorithms: performing an update from rollout storage, saving and restoring optimizer state, and running a post-update hook.
Description
Updater establishes the contract that all policy updaters must follow. Its one abstract method, update, takes a Storage object containing collected rollout data and returns a dictionary that maps loss and metric names to float values. The remaining members have defaults: after_update (called after each update) and load_state_dict (for restoring optimizer state) do nothing, get_resume_state (for serializing optimizer state) returns None, and the lr_scheduler property returns None. Concrete updaters such as PPO override these members with actual optimization logic.
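The contract described above can be sketched without installing habitat-baselines. The block below is only an illustration: UpdaterSketch is a hypothetical local stand-in that mirrors the signature listed on this page, Storage is replaced by a plain list of floats, and MeanUpdater is an invented subclass. It shows the mechanics only: the abstract method must be overridden, and the other members fall back to their defaults.

```python
import abc
from typing import Any, Dict, List, Optional


class UpdaterSketch(abc.ABC):
    """Hypothetical stand-in mirroring the Updater interface on this page."""

    @abc.abstractmethod
    def update(self, rollouts: List[float]) -> Dict[str, float]:
        """Perform one policy update and return named metrics."""

    @property
    def lr_scheduler(self) -> Optional[Any]:
        return None  # no scheduler by default

    def after_update(self) -> None:
        pass  # default hook does nothing

    def get_resume_state(self) -> Optional[Dict[str, Any]]:
        return None  # nothing to checkpoint by default

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        pass  # nothing to restore by default


class MeanUpdater(UpdaterSketch):
    """Illustrative subclass: reports the mean of the stand-in rollout data."""

    def update(self, rollouts: List[float]) -> Dict[str, float]:
        return {"mean_value": sum(rollouts) / len(rollouts)}


# The abstract base cannot be instantiated because update is not implemented.
try:
    UpdaterSketch()
    base_instantiable = True
except TypeError:
    base_instantiable = False

updater = MeanUpdater()
metrics = updater.update([1.0, 2.0, 3.0])
```

A subclass that overrides only update is already a valid updater; the checkpointing methods and the hook can be added later as the algorithm needs them.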
Usage
Subclass Updater to implement a specific policy optimization algorithm (e.g., PPO, SAC). Register the concrete implementation with the baseline registry via baseline_registry.register_updater.
Code Reference
Source Location
- Repository: Facebookresearch_Habitat_lab
- File: habitat-baselines/habitat_baselines/rl/ppo/updater.py
- Lines: 11-39
Signature
class Updater(abc.ABC):
    @abc.abstractmethod
    def update(self, rollouts: Storage) -> Dict[str, float]:
        ...
    @property
    def lr_scheduler(self):
        ...
    def after_update(self) -> None:
        ...
    def get_resume_state(self) -> Dict[str, Any]:
        ...
    def load_state_dict(self, state: Dict[str, Any]) -> None:
        ...
Import
from habitat_baselines.rl.ppo.updater import Updater
I/O Contract
Inputs (update)
| Name | Type | Required | Description |
|---|---|---|---|
| rollouts | Storage | Yes | Rollout storage containing collected experience data for the update |
Outputs (update)
| Name | Type | Description |
|---|---|---|
| losses | Dict[str, float] | Dictionary mapping loss/metric names to their float values for the current update |
Key Methods
update (abstract)
def update(self, rollouts: Storage) -> Dict[str, float]
Performs a policy update from data in the storage object. Must be implemented by subclasses.
after_update
def after_update(self) -> None
Hook called after the policy update completes. Default implementation is a no-op.
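As a rough illustration of where the hook sits, the sketch below assumes the caller (here a hypothetical train function) invokes after_update once after every update call. CountingUpdater and its num_updates counter are invented for this example, and the rollout data is a plain list of floats rather than a Storage object.

```python
from typing import Dict, List


class CountingUpdater:
    """Hypothetical updater that tracks how many updates have completed."""

    def __init__(self) -> None:
        self.num_updates = 0

    def update(self, rollouts: List[float]) -> Dict[str, float]:
        # Placeholder "loss": mean squared value of the collected data.
        return {"loss": sum(x * x for x in rollouts) / len(rollouts)}

    def after_update(self) -> None:
        # Bookkeeping that should run once per completed update.
        self.num_updates += 1


def train(
    updater: CountingUpdater, batches: List[List[float]]
) -> List[Dict[str, float]]:
    """Hypothetical training loop: update first, then run the hook."""
    history = []
    for rollouts in batches:
        history.append(updater.update(rollouts))
        updater.after_update()
    return history


updater = CountingUpdater()
history = train(updater, [[1.0, 3.0], [2.0, 2.0]])
```

Keeping per-update bookkeeping in the hook rather than in update leaves update focused on the optimization step itself.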
get_resume_state
def get_resume_state(self) -> Dict[str, Any]
Returns the optimizer state for checkpointing. The default implementation returns None, even though the return annotation is Dict[str, Any].
load_state_dict
def load_state_dict(self, state: Dict[str, Any]) -> None
Restores the optimizer state from a checkpoint. Default is a no-op.
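The pairing of get_resume_state and load_state_dict can be sketched as a save and restore round trip. Everything in the block is an assumption made for illustration: MomentumUpdater is a hypothetical toy updater written in plain Python, the rollout data is a list of floats, and for brevity the single toy parameter is stored alongside the momentum buffer.

```python
from typing import Any, Dict, List


class MomentumUpdater:
    """Hypothetical updater whose state is a weight, a momentum buffer, and a step count."""

    def __init__(self, lr: float = 0.1, beta: float = 0.9) -> None:
        self.lr = lr
        self.beta = beta
        self.weight = 0.0
        self.momentum = 0.0
        self.steps = 0

    def update(self, rollouts: List[float]) -> Dict[str, float]:
        # Toy objective: pull `weight` towards the mean of the data.
        target = sum(rollouts) / len(rollouts)
        grad = self.weight - target
        self.momentum = self.beta * self.momentum + grad
        self.weight -= self.lr * self.momentum
        self.steps += 1
        return {"loss": 0.5 * grad * grad}

    def get_resume_state(self) -> Dict[str, Any]:
        # Everything needed to continue training exactly where it stopped.
        return {
            "weight": self.weight,
            "momentum": self.momentum,
            "steps": self.steps,
        }

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.weight = state["weight"]
        self.momentum = state["momentum"]
        self.steps = state["steps"]


data = [1.0, 2.0, 3.0]

original = MomentumUpdater()
for _ in range(3):
    original.update(data)
checkpoint = original.get_resume_state()

# A fresh updater picks up from the checkpoint instead of starting over.
resumed = MomentumUpdater()
resumed.load_state_dict(checkpoint)

next_original = original.update(data)
next_resumed = resumed.update(data)
```

After the restore, both updaters produce the same metrics for the same data, which is the property a resume state is meant to guarantee.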
Usage Examples
Basic Usage
from typing import Any, Dict

import torch

from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.storage import Storage
from habitat_baselines.rl.ppo.updater import Updater


@baseline_registry.register_updater
class MyPPOUpdater(Updater):
    def __init__(self, actor_critic, clip_param, ppo_epoch, num_mini_batch, lr):
        self.actor_critic = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.num_mini_batch = num_mini_batch
        self.optimizer = torch.optim.Adam(actor_critic.parameters(), lr=lr)

    def _compute_loss(self, batch) -> torch.Tensor:
        # Placeholder helper: the clipped PPO objective is not shown here.
        raise NotImplementedError

    def update(self, rollouts: Storage) -> Dict[str, float]:
        total_loss = 0.0
        for _ in range(self.ppo_epoch):
            # The generator name and arguments follow this example and may
            # differ for the Storage implementation you use.
            for batch in rollouts.recurrent_generator(self.num_mini_batch):
                loss = self._compute_loss(batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
        # Loss summed over every mini-batch of every epoch.
        return {"ppo_loss": total_loss}

    def get_resume_state(self) -> Dict[str, Any]:
        return {"optimizer": self.optimizer.state_dict()}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.optimizer.load_state_dict(state["optimizer"])