

Implementation:Pyro ppl Pyro PyroLRScheduler

From Leeroopedia


Property | Value
Module | pyro.optim.lr_scheduler
Source | pyro/optim/lr_scheduler.py
Lines | 64
Classes | PyroLRScheduler
Parent Class | pyro.optim.optim.PyroOptim
Dependencies | pyro.optim.optim

Overview

PyroLRScheduler wraps PyTorch learning-rate schedulers (from torch.optim.lr_scheduler) so that they work with Pyro's dynamic parameter management. In Pyro, parameters can be created on the fly during model execution. A standard PyTorch scheduler is bound to a fixed optimizer at construction time, so it would never see parameters created later. PyroLRScheduler solves this by creating a separate optimizer-scheduler pair for each parameter the first time that parameter is encountered.
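The per-parameter bookkeeping described above can be illustrated with a toy model. This is a pure-Python sketch, not Pyro's code. ToyExponentialLR and LazyPerParamScheduler are hypothetical stand-ins for torch.optim.lr_scheduler.ExponentialLR and PyroLRScheduler, and parameters are represented by name strings instead of tensors.

```python
class ToyExponentialLR:
    """Hypothetical stand-in for ExponentialLR: multiplies lr by gamma per step."""

    def __init__(self, lr, gamma):
        self.lr = lr
        self.gamma = gamma

    def step(self):
        self.lr *= self.gamma


class LazyPerParamScheduler:
    """Hypothetical sketch: one scheduler per parameter, created on first sight."""

    def __init__(self, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.schedulers = {}  # parameter name -> its own scheduler

    def __call__(self, param_names):
        for name in param_names:
            if name not in self.schedulers:
                # First time this parameter appears: give it a fresh scheduler
                # that starts from the initial learning rate.
                self.schedulers[name] = ToyExponentialLR(self.lr, self.gamma)

    def step(self):
        # Advance every scheduler created so far.
        for sched in self.schedulers.values():
            sched.step()


wrapper = LazyPerParamScheduler(lr=1.0, gamma=0.5)
wrapper(["loc"])           # "loc" exists from the first iteration
wrapper.step()             # loc: 1.0 -> 0.5
wrapper(["loc", "scale"])  # "scale" only appears later
print({k: s.lr for k, s in wrapper.schedulers.items()})  # {'loc': 0.5, 'scale': 1.0}
```

One consequence shows up in the output. A parameter that first appears late starts from the initial learning rate rather than the already-decayed one, because its optimizer and scheduler are built fresh. As far as we can tell, PyroLRScheduler behaves the same way, since each new pair is constructed from the original optimizer arguments.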

The optim_args dictionary must contain a special 'optimizer' key pointing to the PyTorch optimizer class, and an 'optim_args' key with the optimizer's keyword arguments. Any remaining keys (e.g., 'gamma', 'step_size') are passed to the scheduler constructor.
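The key split can be sketched as a small helper. split_scheduler_args is a hypothetical name, and SGD below is a placeholder class standing in for torch.optim.SGD. The sketch copies the dict before popping. The real constructor appears to pop the two special keys from the dict you pass in, so reusing one dict for two schedulers is best avoided.

```python
class SGD:
    """Placeholder standing in for torch.optim.SGD."""


def split_scheduler_args(optim_args):
    """Hypothetical helper: split optim_args into its three roles."""
    args = dict(optim_args)                     # work on a copy
    optimizer_cls = args.pop("optimizer")       # PyTorch optimizer class
    optimizer_kwargs = args.pop("optim_args")   # kwargs for the optimizer
    scheduler_kwargs = args                     # everything left goes to the scheduler
    return optimizer_cls, optimizer_kwargs, scheduler_kwargs


cls, opt_kw, sched_kw = split_scheduler_args(
    {"optimizer": SGD, "optim_args": {"lr": 0.01}, "gamma": 0.1}
)
print(cls.__name__, opt_kw, sched_kw)  # SGD {'lr': 0.01} {'gamma': 0.1}
```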

Code Reference

Class: PyroLRScheduler

Constructor:

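  • scheduler_constructor: A scheduler class from torch.optim.lr_scheduler (e.g., ExponentialLR). The pre-built wrappers such as pyro.optim.ExponentialLR supply this argument for you.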
  • optim_args (dict): Must contain:
    • 'optimizer': PyTorch optimizer class (e.g., torch.optim.SGD).
    • 'optim_args': Dict of optimizer kwargs (e.g., {'lr': 0.01}).
    • Any additional keys are passed to the scheduler (e.g., 'gamma': 0.1).
  • clip_args (dict, optional): Gradient clipping arguments forwarded to PyroOptim, using the keys 'clip_norm' and/or 'clip_value'.
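The two clipping modes accepted through clip_args behave differently. Their effect can be sketched on a plain list of gradient values. clip_by_norm and clip_by_value are hypothetical pure-Python stand-ins for torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_, which we assume PyroOptim calls for 'clip_norm' and 'clip_value' respectively. The sketch is simplified: it handles a single gradient and returns a new list instead of modifying tensors in place.

```python
import math


def clip_by_norm(grads, max_norm):
    """Rescale grads so their L2 norm is at most max_norm (like clip_grad_norm_)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + 1e-6)  # small epsilon avoids division by zero
    if coef < 1.0:
        return [g * coef for g in grads]
    return list(grads)


def clip_by_value(grads, clip_value):
    """Clamp each entry into [-clip_value, clip_value] (like clip_grad_value_)."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


print(clip_by_norm([3.0, 4.0], max_norm=1.0))           # roughly [0.6, 0.8]
print(clip_by_value([3.0, -4.0, 0.5], clip_value=1.0))  # [1.0, -1.0, 0.5]
```

Norm clipping preserves the gradient's direction and shrinks its length. Value clipping caps each component independently and can therefore change the direction.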

Methods:

  • __call__(params, *args, **kwargs): Creates an optimizer-scheduler pair for any parameter not seen before, then updates each parameter by stepping its underlying optimizer (the scheduler itself is not stepped here).
  • _get_optim(params): Internal helper that creates a PyTorch optimizer for a single parameter and returns it wrapped in the scheduler.
  • step(*args, **kwargs): Advances every per-parameter scheduler by one step. Takes the same arguments as the underlying PyTorch scheduler (e.g., the loss required by ReduceLROnPlateau).
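The division of labour among the three methods can be modelled with toy objects. In this model, __call__ lazily builds a scheduler per parameter via _get_optim and steps the optimizer held inside it, while step forwards its arguments to every scheduler. ToyOptimizer, ToyScheduler, and ToyLRScheduler are hypothetical stand-ins, not Pyro's implementation. Like real PyTorch schedulers, the toy scheduler keeps its optimizer in an .optimizer attribute.

```python
class ToyOptimizer:
    """Hypothetical optimizer that only counts its steps."""

    def __init__(self, param, lr):
        self.param = param
        self.lr = lr
        self.num_steps = 0

    def step(self):
        self.num_steps += 1


class ToyScheduler:
    """Hypothetical scheduler; like torch schedulers it keeps `.optimizer`."""

    def __init__(self, optimizer, **kwargs):
        self.optimizer = optimizer
        self.kwargs = kwargs
        self.step_args = []  # record what each step() call received

    def step(self, *args):
        self.step_args.append(args)


class ToyLRScheduler:
    """Hypothetical sketch of the __call__ / _get_optim / step split."""

    def __init__(self, optimizer_kwargs, **scheduler_kwargs):
        self.optimizer_kwargs = optimizer_kwargs
        self.scheduler_kwargs = scheduler_kwargs
        self.schedulers = {}  # parameter -> scheduler wrapping its own optimizer

    def _get_optim(self, param):
        optimizer = ToyOptimizer(param, **self.optimizer_kwargs)
        return ToyScheduler(optimizer, **self.scheduler_kwargs)

    def __call__(self, params):
        for p in params:
            if p not in self.schedulers:
                self.schedulers[p] = self._get_optim(p)
            # Parameter updates go through the wrapped optimizer, not the scheduler.
            self.schedulers[p].optimizer.step()

    def step(self, *args):
        # Learning-rate updates: forward args (e.g. a loss) to every scheduler.
        for scheduler in self.schedulers.values():
            scheduler.step(*args)


opt = ToyLRScheduler({"lr": 0.01}, gamma=0.1)
opt(["loc", "scale"])  # one training iteration
opt(["loc", "scale"])  # another iteration
opt.step(0.42)         # end of epoch, e.g. passing a loss
print([s.optimizer.num_steps for s in opt.schedulers.values()])  # [2, 2]
print([s.step_args for s in opt.schedulers.values()])  # [[(0.42,)], [(0.42,)]]
```

Keeping the two kinds of step separate lets parameter updates run every iteration while learning-rate changes happen only when the user calls step, typically once per epoch.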

I/O Contract

Method | Input | Output
__init__ | scheduler_constructor (scheduler class), optim_args: dict, optional clip_args: dict | PyroLRScheduler instance
__call__ | params: List or ValuesView of parameter tensors | None (updates params in place)
step | Scheduler-specific args (e.g., loss for ReduceLROnPlateau) | None
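The reason step needs a loss for ReduceLROnPlateau becomes clearer from that scheduler's decision rule. The following is a simplified pure-Python sketch of the rule for mode='min', with the improvement threshold, cooldown, and min_lr omitted. ToyPlateau is a hypothetical name, and PyTorch's actual scheduler has more options.

```python
class ToyPlateau:
    """Simplified sketch of ReduceLROnPlateau's rule (mode='min', no threshold)."""

    def __init__(self, lr, factor=0.1, patience=10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, loss):
        if loss < self.best:
            # Improvement: remember it and reset the counter.
            self.best = loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Stalled for more than `patience` epochs: shrink the learning rate.
            self.lr *= self.factor
            self.num_bad_epochs = 0


sched = ToyPlateau(lr=1.0, factor=0.5, patience=2)
for loss in [1.0, 1.0, 1.0, 1.0]:
    sched.step(loss)
print(sched.lr)  # 0.5
```

Unlike ExponentialLR, which decays on a fixed schedule, this rule depends on the observed loss. That dependence is why the loss has to be passed to step.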

Usage Examples

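import torch
import pyro
import pyro.distributions as dist
import pyro.optim
from pyro.infer import SVI, Trace_ELBO
from torch.distributions import constraints

# Toy model and guide (illustrative): infer the mean of Gaussian data
data = torch.randn(100) + 3.0

def model(x):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)

def guide(x):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    q_scale = pyro.param("q_scale", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("loc", dist.Normal(q_loc, q_scale))

num_epochs = 10
dataloader = data.split(10)  # stand-in for a torch DataLoader yielding minibatches

# ExponentialLR scheduler with SGD optimizer
scheduler = pyro.optim.ExponentialLR({
    'optimizer': torch.optim.SGD,
    'optim_args': {'lr': 0.01},
    'gamma': 0.1
})

svi = SVI(model, guide, scheduler, loss=Trace_ELBO())

for epoch in range(num_epochs):
    for minibatch in dataloader:
        svi.step(minibatch)
    scheduler.step()  # decay learning rate once per epoch

# ReduceLROnPlateau (its step() requires the loss)
pyro.clear_param_store()
scheduler2 = pyro.optim.ReduceLROnPlateau({
    'optimizer': torch.optim.Adam,
    'optim_args': {'lr': 0.01},
    'patience': 5,
    'factor': 0.5,
})
# A new SVI object is needed so that parameter updates go through scheduler2
svi2 = SVI(model, guide, scheduler2, loss=Trace_ELBO())
for epoch in range(num_epochs):
    loss = svi2.step(data)
    scheduler2.step(loss)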
