Implementation:Pyro ppl Pyro DCTAdam
| Property | Value |
|---|---|
| Module | pyro.optim.dct_adam |
| Source | pyro/optim/dct_adam.py |
| Lines | 215 |
| Classes | DCTAdam |
| Parent Class | torch.optim.optimizer.Optimizer |
| Status | EXPERIMENTAL |
| Dependencies | torch, pyro.ops.tensor_utils (dct, idct, next_fast_len) |
Overview
DCTAdam is an experimental optimizer that augments Adam with Discrete Cosine Transform (DCT) processing for time-series parameters. For most parameters, it behaves like ClippedAdam. However, when a parameter has a ._pyro_dct_dim attribute identifying its time dimension, DCTAdam also optimizes that parameter in the frequency domain, alongside the usual time-domain update.
The algorithm:
- Forward-transforms the gradient to the frequency domain using DCT.
- Concatenates time-domain and frequency-domain representations.
- Applies standard Adam updates in this augmented space.
- Inverse-transforms the frequency-domain part of the update and adds it to the time-domain part, giving the final time-domain update.
This dual-domain approach is particularly useful for time series models where parameters exhibit both local and global temporal structure.
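The four steps above can be sketched in plain Python on a single 1-D parameter. This is an illustrative sketch only, not the Pyro implementation: the names dct_ortho, idct_ortho, transform_forward, transform_inverse, and dual_domain_adam_step are hypothetical, the DCT is assumed to be the orthonormal type II transform, and both gradient clipping and padding to a fast FFT length are omitted.

```python
import math


def dct_ortho(x):
    """Orthonormal DCT-II of a 1-D list of floats (O(n^2), for illustration)."""
    n = len(x)
    out = []
    for k in range(n):
        s = sum(x[t] * math.cos(math.pi * (t + 0.5) * k / n) for t in range(n))
        scale = math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)
        out.append(scale * s)
    return out


def idct_ortho(X):
    """Inverse of dct_ortho (an orthonormal DCT-III)."""
    n = len(X)
    out = []
    for t in range(n):
        s = math.sqrt(1.0 / n) * X[0]
        for k in range(1, n):
            s += math.sqrt(2.0 / n) * X[k] * math.cos(math.pi * (t + 0.5) * k / n)
        out.append(s)
    return out


def transform_forward(grad):
    """Steps 1-2: concatenate the time-domain gradient with its DCT."""
    return list(grad) + dct_ortho(grad)


def transform_inverse(step, duration):
    """Step 4: add the time-domain half to the inverse DCT of the other half."""
    time_part, freq_part = step[:duration], step[duration:]
    return [a + b for a, b in zip(time_part, idct_ortho(freq_part))]


def dual_domain_adam_step(param, grad, state, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    """One Adam update computed in the augmented (time + frequency) space."""
    duration = len(param)
    g = transform_forward(grad)
    if not state:
        # Moment estimates live in the augmented space (twice the duration).
        state["t"] = 0
        state["m"] = [0.0] * len(g)
        state["v"] = [0.0] * len(g)
    state["t"] += 1
    t = state["t"]
    b1, b2 = betas
    state["m"] = [b1 * m + (1 - b1) * gi for m, gi in zip(state["m"], g)]
    state["v"] = [b2 * v + (1 - b2) * gi * gi for v, gi in zip(state["v"], g)]
    # Step 3: standard bias-corrected Adam direction, elementwise.
    step = [
        (m / (1 - b1 ** t)) / (math.sqrt(v / (1 - b2 ** t)) + eps)
        for m, v in zip(state["m"], state["v"])
    ]
    update = transform_inverse(step, duration)
    return [p - lr * u for p, u in zip(param, update)]
```

Because the moment estimates are kept separately for each time step and each frequency component, a gradient that is small at every time step but concentrated in one frequency still receives its own adaptive scaling in this sketch.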
The optimizer also supports subsample-aware updates: when a parameter has a ._pyro_subsample attribute, only the subsampled entries are updated, improving efficiency for large datasets.
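As a rough illustration of a subsample-aware update, the sketch below applies an Adam step only to the entries selected by a list of indices and keeps a separate update count for each entry. It is a simplified 1-D sketch under stated assumptions, not the Pyro code: get_mask and subsample_adam_step are hypothetical names, and using the per-element count for bias correction is an assumption about how such counts would be used.

```python
import math


def get_mask(length, indices):
    """Boolean mask that is True exactly at the subsampled positions."""
    mask = [False] * length
    for i in indices:
        mask[i] = True
    return mask


def subsample_adam_step(param, grad, indices, state, lr=0.01,
                        betas=(0.9, 0.999), eps=1e-8):
    """Adam step that touches only the entries listed in `indices`."""
    n = len(param)
    if not state:
        state["m"] = [0.0] * n
        state["v"] = [0.0] * n
        state["count"] = [0] * n  # per-element number of updates so far
    b1, b2 = betas
    mask = get_mask(n, indices)
    new_param = list(param)
    for i in range(n):
        if not mask[i]:
            continue  # entries outside the subsample keep their value and state
        state["count"][i] += 1
        t = state["count"][i]
        state["m"][i] = b1 * state["m"][i] + (1 - b1) * grad[i]
        state["v"][i] = b2 * state["v"][i] + (1 - b2) * grad[i] ** 2
        # Assumption: bias correction uses this entry's own update count.
        m_hat = state["m"][i] / (1 - b1 ** t)
        v_hat = state["v"][i] / (1 - b2 ** t)
        new_param[i] = param[i] - lr * m_hat / (math.sqrt(v_hat) + eps)
    return new_param
```

Tracking counts per entry matters when different entries are visited at different rates: an entry seen for the first time at step 500 would otherwise be bias-corrected as if it already had 500 updates behind it.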
Code Reference
Class: DCTAdam
Constructor:
- params: Iterable of parameters or parameter groups.
- lr (float, default 1e-3): Learning rate.
- betas (tuple, default (0.9, 0.999)): Adam beta coefficients.
- eps (float, default 1e-8): Numerical stability.
- clip_norm (float, default 10.0): Gradient clipping norm.
- lrd (float, default 1.0): Learning rate decay per step.
- subsample_aware (bool, default False): Whether to track per-element update counts.
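To get a feel for the lrd parameter, the sketch below assumes the decay is multiplicative, so that the learning rate after n steps is lr * lrd ** n. The helper names decayed_lr and lrd_for_total_decay are hypothetical and exist only for this illustration.

```python
def decayed_lr(lr, lrd, num_steps):
    """Learning rate after num_steps, assuming it is multiplied by lrd each step."""
    return lr * lrd ** num_steps


def lrd_for_total_decay(total_decay, num_steps):
    """Per-step factor that shrinks the learning rate by total_decay overall."""
    return total_decay ** (1.0 / num_steps)


# Example: shrink the learning rate tenfold over 1000 steps.
lrd = lrd_for_total_decay(0.1, 1000)
final_lr = decayed_lr(0.01, lrd, 1000)
```

With the default lrd of 1.0 the learning rate stays constant under this assumption.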
Methods:
- step(closure=None): Main optimization step. Dispatches to _step_param or _step_param_subsample depending on subsample attributes.
- _step_param(group, p): Standard parameter update with optional DCT transform.
- _step_param_subsample(group, p, subsample): Subsample-aware update using masked scatter/select operations.
Helper Functions
- _transform_forward(x, dim, duration): Forward DCT transform, padding to next fast FFT length.
- _transform_inverse(x, dim, duration): Inverse DCT transform, adding time-domain and inverse-DCT components.
- _get_mask(x, indices): Creates a boolean mask from subsample indices.
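The padding mentioned for _transform_forward can be illustrated with a small sketch. It assumes that a "fast" length is one whose only prime factors are 2, 3, and 5, as in scipy.fftpack.next_fast_len, and that padding appends zeros at the end of the time axis. The names next_fast_len_sketch and pad_to_fast_len are hypothetical and do not come from the Pyro source.

```python
def next_fast_len_sketch(size):
    """Smallest n >= size whose prime factors are all 2, 3, or 5 (size >= 1)."""
    n = size
    while True:
        remaining = n
        for prime in (2, 3, 5):
            while remaining % prime == 0:
                remaining //= prime
        if remaining == 1:
            return n
        n += 1


def pad_to_fast_len(x):
    """Right-pad a 1-D list with zeros up to the next fast length."""
    return list(x) + [0.0] * (next_fast_len_sketch(len(x)) - len(x))
```

A series of length 7, for example, would be padded to length 8 before the transform, while a series of length 100 is already a fast length and is left unchanged.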
I/O Contract
| Method | Input | Output |
|---|---|---|
| __init__ | Parameters, lr, betas, eps, clip_norm, lrd, subsample_aware | DCTAdam instance |
| step(closure) | Optional closure | Optional loss value |
Usage Examples
import torch
from pyro.optim.dct_adam import DCTAdam
# Time series parameter with DCT dimension
param = torch.nn.Parameter(torch.randn(100, 3))
param._pyro_dct_dim = -2 # time dimension is the second-to-last
optimizer = DCTAdam([param], lr=0.01, clip_norm=10.0)
for step in range(100):
    loss = param.pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# With subsample-aware updates for SVI
param2 = torch.nn.Parameter(torch.randn(1000, 5))
param2._pyro_dct_dim = -2
param2._pyro_subsample = {-2: torch.tensor([0, 10, 20])}
optimizer2 = DCTAdam([param2], lr=0.01, subsample_aware=True)
Related Pages
- Pyro_ppl_Pyro_TensorUtils -- Provides dct, idct, next_fast_len
- Pyro_ppl_Pyro_AdagradRMSProp -- Alternative ADVI optimizer
- Pyro_ppl_Pyro_PyroLRScheduler -- Learning rate scheduling wrapper