Implementation:Pyro ppl Pyro PyroLRScheduler
| Property | Value |
|---|---|
| Module | pyro.optim.lr_scheduler |
| Source | pyro/optim/lr_scheduler.py |
| Lines | 64 |
| Classes | PyroLRScheduler |
| Parent Class | pyro.optim.optim.PyroOptim |
| Dependencies | pyro.optim.optim |
Overview
PyroLRScheduler wraps PyTorch learning rate schedulers (from torch.optim.lr_scheduler) so that they work with Pyro's dynamic parameter management. In Pyro, parameters can be created on the fly during model execution. A standard PyTorch scheduler is bound at construction time to a fixed optimizer, and therefore to that optimizer's fixed set of parameters, so it would never see parameters created later. PyroLRScheduler solves this by lazily creating a new optimizer-scheduler pair for each newly encountered parameter.
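The lazy-creation idea can be sketched in plain PyTorch. The class below is a hypothetical, simplified stand-in (`LazySchedulerSketch` is not part of Pyro): it keeps one optimizer-scheduler pair per parameter tensor, creates the pair the first time a parameter is seen, takes an optimizer step on every call, and changes learning rates only when `step()` is called. It deliberately omits Pyro's gradient clipping and state loading.

```python
import torch


class LazySchedulerSketch:
    """Hypothetical illustration of the PyroLRScheduler idea, not Pyro's code."""

    def __init__(self, scheduler_cls, optimizer_cls, optimizer_kwargs, **scheduler_kwargs):
        self.scheduler_cls = scheduler_cls
        self.optimizer_cls = optimizer_cls
        self.optimizer_kwargs = optimizer_kwargs
        self.scheduler_kwargs = scheduler_kwargs
        self.schedulers = {}  # parameter tensor -> its own scheduler

    def _get_optim(self, param):
        # One optimizer per parameter, wrapped by a fresh scheduler instance.
        optimizer = self.optimizer_cls([param], **self.optimizer_kwargs)
        return self.scheduler_cls(optimizer, **self.scheduler_kwargs)

    def __call__(self, params):
        for p in params:
            if p not in self.schedulers:  # first time this tensor is seen
                self.schedulers[p] = self._get_optim(p)
            # Take a gradient step with the wrapped optimizer (not the scheduler).
            self.schedulers[p].optimizer.step()

    def step(self, *args, **kwargs):
        # Advance every scheduler created so far with the same arguments.
        for scheduler in self.schedulers.values():
            scheduler.step(*args, **kwargs)


# Demo: one parameter exists from the start, another appears later.
w = torch.nn.Parameter(torch.zeros(3))
w.grad = torch.ones(3)
sched = LazySchedulerSketch(
    torch.optim.lr_scheduler.ExponentialLR, torch.optim.SGD, {"lr": 0.1}, gamma=0.5
)
sched([w])  # creates the pair for w, then takes one SGD step
b = torch.nn.Parameter(torch.zeros(1))
b.grad = torch.ones(1)
sched([w, b])  # reuses w's pair and creates a new pair for b
sched.step()  # both learning rates decay from 0.1 to 0.05
```

Keying the dictionary by the tensor object works because PyTorch tensors hash by identity, so a parameter that reappears on a later call finds its existing pair.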
The optim_args dictionary must contain a special 'optimizer' key pointing to the PyTorch optimizer class, and an 'optim_args' key with the optimizer's keyword arguments. Any remaining keys (e.g., 'gamma', 'step_size') are passed to the scheduler constructor.
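Because the outer dictionary nests a second 'optim_args' dictionary, it can help to see how the three parts separate. The stdlib-only sketch below uses the hypothetical names `split_scheduler_args` and `FakeSGD` (a stand-in for torch.optim.SGD so it runs without PyTorch); it illustrates the contract described above and is not Pyro's internal code.

```python
from typing import Any, Dict, Tuple


def split_scheduler_args(optim_args: Dict[str, Any]) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: unpack a PyroLRScheduler-style dict into
    (optimizer class, optimizer kwargs, scheduler kwargs)."""
    remaining = dict(optim_args)  # shallow copy, so the caller's dict is not modified
    missing = [key for key in ("optimizer", "optim_args") if key not in remaining]
    if missing:
        raise KeyError(f"optim_args is missing required key(s): {missing}")
    optimizer_cls = remaining.pop("optimizer")
    optimizer_kwargs = remaining.pop("optim_args")
    # Every key left over is meant for the scheduler constructor.
    return optimizer_cls, optimizer_kwargs, remaining


class FakeSGD:  # stand-in for torch.optim.SGD
    pass


example = {"optimizer": FakeSGD, "optim_args": {"lr": 0.01}, "gamma": 0.1}
opt_cls, opt_kwargs, sched_kwargs = split_scheduler_args(example)
```

Here `opt_kwargs` is `{"lr": 0.01}` and `sched_kwargs` is `{"gamma": 0.1}`, matching the roles described in the paragraph above.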
Code Reference
Class: PyroLRScheduler
Constructor:
- scheduler_constructor: A torch.optim.lr_scheduler class (e.g., torch.optim.lr_scheduler.ExponentialLR).
- optim_args (dict): Must contain:
  - 'optimizer': PyTorch optimizer class (e.g., torch.optim.SGD).
  - 'optim_args': Dict of optimizer kwargs (e.g., {'lr': 0.01}).
  - Any additional keys are passed to the scheduler (e.g., 'gamma': 0.1).
- clip_args (dict, optional): Gradient clipping arguments ('clip_norm' and/or 'clip_value').
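Assuming clip_args follows the convention of Pyro's base PyroOptim (keys 'clip_norm' and/or 'clip_value'), its effect on a parameter's gradient can be sketched with PyTorch's own clipping utilities. `apply_clip_args` is a hypothetical helper written for this illustration.

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_


def apply_clip_args(param, clip_norm=None, clip_value=None):
    """Hypothetical helper: clip param.grad in place using clip_args-style keys."""
    if clip_norm is not None:
        # Rescale the gradient so its L2 norm is at most clip_norm.
        clip_grad_norm_(param, clip_norm)
    if clip_value is not None:
        # Clamp each gradient element into [-clip_value, clip_value].
        clip_grad_value_(param, clip_value)


p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([3.0, 4.0])  # L2 norm is 5
apply_clip_args(p, **{"clip_norm": 1.0})  # grad is rescaled to roughly [0.6, 0.8]

q = torch.nn.Parameter(torch.zeros(2))
q.grad = torch.tensor([-2.0, 0.5])
apply_clip_args(q, **{"clip_value": 1.0})  # grad becomes [-1.0, 0.5]
```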
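Methods:
- __call__(params, *args, **kwargs): Performs an optimization step on params, creating a new optimizer-scheduler pair for any parameter not seen before. This steps the wrapped optimizers only; learning rates change when step() is called.
- _get_optim(params): Creates a PyTorch optimizer for the given parameter and wraps it in the scheduler, passing along the extra scheduler keys.
- step(*args, **kwargs): Advances every wrapped scheduler by one step. Takes the same arguments as the underlying PyTorch scheduler (e.g., the loss value that ReduceLROnPlateau requires).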
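Since each parameter has its own scheduler, a single step(loss) call has to reach all of them. The plain-PyTorch sketch below (no Pyro involved; the two parameters and the helper `step_all` are hypothetical) forwards the same loss to two per-parameter ReduceLROnPlateau schedulers. With `patience=1`, a loss that stops improving triggers a reduction on the third call, halving every learning rate.

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Two parameters, each with its own optimizer/scheduler pair
# (an illustration of the one-pair-per-parameter layout, not Pyro's code).
params = [torch.nn.Parameter(torch.zeros(1)), torch.nn.Parameter(torch.zeros(2))]
schedulers = [
    ReduceLROnPlateau(Adam([p], lr=0.01), mode="min", factor=0.5, patience=1)
    for p in params
]


def step_all(*args, **kwargs):
    # Forward the same arguments (here, the monitored loss) to every scheduler.
    for scheduler in schedulers:
        scheduler.step(*args, **kwargs)


# The first call sets the best loss; the next two do not improve on it,
# which exceeds patience=1, so each learning rate is multiplied by 0.5.
for loss in [1.0, 1.0, 1.0]:
    step_all(loss)

lrs = [s.optimizer.param_groups[0]["lr"] for s in schedulers]  # [0.005, 0.005]
```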
I/O Contract
| Method | Input | Output |
|---|---|---|
| __init__ | Scheduler class, optim_args: dict, optional clip_args | PyroLRScheduler instance |
| __call__ | params: List/ValuesView | None (updates params in-place) |
| step | Scheduler-specific args (e.g., loss) | None |
Usage Examples
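```python
import torch
import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim

# Minimal placeholder model, guide, and data so the example runs end to end;
# replace them with your own.
def model(x):
    loc = pyro.param("loc", torch.tensor(0.0))
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=x)

def guide(x):
    pass  # no latent variables to approximate

data = torch.randn(100) + 3.0
dataloader = data.split(10)  # ten minibatches of 10 points each
num_epochs = 20

# ExponentialLR scheduler with SGD optimizer
scheduler = pyro.optim.ExponentialLR({
    'optimizer': torch.optim.SGD,
    'optim_args': {'lr': 0.01},
    'gamma': 0.1,
})
svi = pyro.infer.SVI(model, guide, scheduler, loss=pyro.infer.Trace_ELBO())
for epoch in range(num_epochs):
    for minibatch in dataloader:
        svi.step(minibatch)
    scheduler.step()  # decay the learning rate once per epoch

# ReduceLROnPlateau (its step() requires the monitored loss)
scheduler2 = pyro.optim.ReduceLROnPlateau({
    'optimizer': torch.optim.Adam,
    'optim_args': {'lr': 0.01},
    'patience': 5,
    'factor': 0.5,
})
svi2 = pyro.infer.SVI(model, guide, scheduler2, loss=pyro.infer.Trace_ELBO())
for epoch in range(num_epochs):
    loss = svi2.step(data)
    scheduler2.step(loss)
```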
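When `scheduler.step()` is called once per epoch, as the ExponentialLR example intends, the learning rate after k epochs is `0.01 * 0.1**k`. The plain-PyTorch sketch below (a single hypothetical parameter with a zero gradient, no Pyro involved; `lr_after` is a name made up for this illustration) checks that arithmetic and shows how much faster the rate collapses if the call is accidentally moved inside the minibatch loop.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR


def lr_after(num_epochs, batches_per_epoch, step_per_batch):
    """Return the final learning rate of SGD(lr=0.01) wrapped in
    ExponentialLR(gamma=0.1), stepping the scheduler per batch or per epoch."""
    p = torch.nn.Parameter(torch.zeros(1))
    p.grad = torch.zeros(1)  # zero gradient: only the schedule matters here
    optimizer = SGD([p], lr=0.01)
    scheduler = ExponentialLR(optimizer, gamma=0.1)
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            optimizer.step()
            if step_per_batch:
                scheduler.step()
        if not step_per_batch:
            scheduler.step()
    return optimizer.param_groups[0]["lr"]


per_epoch = lr_after(3, 10, step_per_batch=False)  # 0.01 * 0.1**3
per_batch = lr_after(3, 10, step_per_batch=True)   # 0.01 * 0.1**30
```

With ten minibatches per epoch, stepping per batch decays the rate by a factor of 0.1**10 every epoch instead of 0.1, which is rarely what a per-epoch `gamma` was chosen for.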
Related Pages
- Pyro_ppl_Pyro_AdagradRMSProp -- Custom optimizer for ADVI
- Pyro_ppl_Pyro_DCTAdam -- DCT-augmented optimizer
- Pyro_ppl_Pyro_MultiOptimizer -- Higher-order optimization framework