Implementation:Pyro ppl Pyro DCTAdam

From Leeroopedia


Property | Value
Module | pyro.optim.dct_adam
Source | pyro/optim/dct_adam.py
Lines | 215
Classes | DCTAdam
Parent Class | torch.optim.optimizer.Optimizer
Status | EXPERIMENTAL
Dependencies | torch, pyro.ops.tensor_utils (dct, idct, next_fast_len)

Overview

DCTAdam is an experimental optimizer that augments Adam with Discrete Cosine Transform (DCT) processing for time-series parameters. For most parameters it behaves like ClippedAdam. When a parameter has a ._pyro_dct_dim attribute indicating its time dimension, DCTAdam additionally optimizes that parameter in the frequency domain along that dimension.

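For a parameter with a time dimension, each step:

  1. Forward-transforms the gradient to the frequency domain using the DCT.
  2. Concatenates the time-domain and frequency-domain representations along the time dimension.
  3. Applies standard Adam updates in this augmented space.
  4. Maps the result back to the time domain by adding the time-domain component to the inverse DCT of the frequency-domain component.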
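
The four steps above can be sketched in plain Python for a one-dimensional parameter. This is an illustration under stated assumptions, not Pyro's implementation: the helper names (dct_ortho, idct_ortho, augmented_adam_step) are hypothetical, the transform is a naive orthonormal DCT-II, and padding to a fast FFT length, gradient clipping, and learning rate decay are all omitted.

```python
import math


def dct_ortho(x):
    """Naive orthonormal DCT-II of a list of floats (O(n^2), for illustration)."""
    n = len(x)
    out = []
    for k in range(n):
        scale = math.sqrt((1.0 if k == 0 else 2.0) / n)
        total = sum(x[t] * math.cos(math.pi * (t + 0.5) * k / n) for t in range(n))
        out.append(scale * total)
    return out


def idct_ortho(X):
    """Inverse of dct_ortho (transpose of the orthonormal DCT-II matrix)."""
    n = len(X)
    out = []
    for t in range(n):
        out.append(sum(
            math.sqrt((1.0 if k == 0 else 2.0) / n)
            * X[k] * math.cos(math.pi * (t + 0.5) * k / n)
            for k in range(n)
        ))
    return out


def augmented_adam_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One Adam step in the concatenated time + frequency space (hypothetical helper)."""
    duration = len(grad)
    beta1, beta2 = betas
    # Steps 1-2: transform the gradient and concatenate both representations.
    aug_grad = list(grad) + dct_ortho(grad)
    if not state:
        state.update(step=0, m=[0.0] * len(aug_grad), v=[0.0] * len(aug_grad))
    state["step"] += 1
    t = state["step"]
    # Step 3: standard Adam moment updates, elementwise in the augmented space.
    aug_step = []
    for i, g in enumerate(aug_grad):
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g * g
        m_hat = state["m"][i] / (1 - beta1 ** t)
        v_hat = state["v"][i] / (1 - beta2 ** t)
        aug_step.append(m_hat / (math.sqrt(v_hat) + eps))
    # Step 4: add the time-domain part to the inverse DCT of the frequency part.
    freq_part = idct_ortho(aug_step[duration:])
    update = [a + b for a, b in zip(aug_step[:duration], freq_part)]
    return [p - lr * u for p, u in zip(param, update)]


param = [0.0, 0.0, 0.0, 0.0]
grad = [1.0, 1.0, 1.0, 1.0]  # a constant gradient has only a zero-frequency component
state = {}
new_param = augmented_adam_step(param, grad, state, lr=0.1)
```

With a constant gradient, the frequency-domain copy concentrates in the zero-frequency entry, so in this sketch every time step receives the same extra contribution on top of its time-domain update (about 1.5 times lr in total on the first step).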

This dual-domain approach targets time series models whose parameters exhibit both local and global temporal structure.

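The optimizer also supports subsample-aware updates: when subsample_aware is enabled and a parameter has a ._pyro_subsample attribute (a mapping from dimension to subsample indices), only the subsampled entries are updated. This avoids touching entries that were not part of the current minibatch, which matters for large datasets.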
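
The idea of updating only subsampled entries can be sketched as follows for a one-dimensional parameter. The helper name subsample_adam_step is hypothetical, the DCT augmentation and gradient clipping are left out, and the per-element step counts used for bias correction are an assumption about how per-element update counts might be used, not a description of Pyro's code.

```python
import math


def subsample_adam_step(param, grad, indices, state, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    """Update only the entries listed in `indices` (hypothetical helper, 1-D case)."""
    beta1, beta2 = betas
    n = len(param)
    if not state:
        # Per-element step counts, so bias correction reflects how often
        # each entry has actually appeared in a subsample.
        state.update(step=[0] * n, m=[0.0] * n, v=[0.0] * n)
    selected = set(indices)
    mask = [i in selected for i in range(n)]  # boolean mask built from the indices
    for i in range(n):
        if not mask[i]:
            continue  # entries outside the subsample keep their value and statistics
        state["step"][i] += 1
        t = state["step"][i]
        g = grad[i]
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g * g
        m_hat = state["m"][i] / (1 - beta1 ** t)
        v_hat = state["v"][i] / (1 - beta2 ** t)
        param[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param


param = [1.0] * 6
grad = [2.0 * p for p in param]  # gradient of sum(p ** 2)
state = {}
subsample_adam_step(param, grad, indices=[0, 3], state=state)
```

Only entries 0 and 3 move here; the remaining entries, their moment estimates, and their step counts stay untouched.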

Code Reference

Class: DCTAdam

Constructor:

  • params: Iterable of parameters or parameter groups.
  • lr (float, default 1e-3): Learning rate.
  • betas (tuple, default (0.9, 0.999)): Adam beta coefficients.
  • eps (float, default 1e-8): Term added to the denominator for numerical stability.
  • clip_norm (float, default 10.0): Gradient clipping norm.
  • lrd (float, default 1.0): Learning rate decay factor applied per step; 1.0 means no decay.
  • subsample_aware (bool, default False): Whether to track per-element update counts.
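
As a rough guide to choosing lrd, the sketch below assumes the decay is multiplicative and applied once per optimizer step, so that the learning rate after n steps is lr * lrd ** n. The helper names decayed_lr and lrd_for_target are hypothetical and not part of Pyro.

```python
def decayed_lr(lr, lrd, num_steps):
    """Learning rate after `num_steps` multiplicative decays (assumed schedule)."""
    return lr * lrd ** num_steps


def lrd_for_target(initial_lr, final_lr, num_steps):
    """Per-step factor that takes `initial_lr` to `final_lr` over `num_steps` steps."""
    return (final_lr / initial_lr) ** (1.0 / num_steps)


# Decay from 0.01 to 0.001 over 1000 steps.
lrd = lrd_for_target(0.01, 0.001, 1000)
final_lr = decayed_lr(0.01, lrd, 1000)
```

Under this assumption, a value of lrd slightly below 1.0 gives a smooth exponential schedule, and the default of 1.0 leaves the learning rate constant.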

Methods:

  • step(closure=None): Main optimization step. Dispatches to _step_param or _step_param_subsample depending on subsample attributes.
  • _step_param(group, p): Standard parameter update with optional DCT transform.
  • _step_param_subsample(group, p, subsample): Subsample-aware update using masked scatter/select operations.

Helper Functions

  • _transform_forward(x, dim, duration): Forward DCT transform, padding to next fast FFT length.
  • _transform_inverse(x, dim, duration): Inverse DCT transform, adding time-domain and inverse-DCT components.
  • _get_mask(x, indices): Creates a boolean mask from subsample indices.
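
The padding performed before the forward transform can be illustrated with a small sketch. It assumes that a "fast" length means the smallest integer at least as large as the input whose prime factors are all 2, 3, or 5; the names next_fast_len_sketch and pad_to_fast_len are hypothetical stand-ins, not the functions in pyro.ops.tensor_utils.

```python
def next_fast_len_sketch(size):
    """Smallest integer >= size with prime factors only in {2, 3, 5} (assumed definition)."""
    assert isinstance(size, int) and size > 0
    candidate = size
    while True:
        remaining = candidate
        for prime in (2, 3, 5):
            while remaining % prime == 0:
                remaining //= prime
        if remaining == 1:
            return candidate
        candidate += 1


def pad_to_fast_len(x):
    """Zero-pad a 1-D sequence on the right up to the next fast length."""
    return list(x) + [0.0] * (next_fast_len_sketch(len(x)) - len(x))


padded = pad_to_fast_len([1.0] * 7)  # a length-7 time dimension is padded to length 8
```

In this sketch, a parameter with a time dimension of length 7 would have 7 + 8 = 15 entries along that dimension once the time-domain and padded frequency-domain parts are concatenated.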

I/O Contract

Method | Input | Output
__init__ | Parameters, lr, betas, eps, clip_norm, lrd, subsample_aware | DCTAdam instance
step(closure) | Optional closure | Optional loss value

Usage Examples

import torch
from pyro.optim.dct_adam import DCTAdam

# Time series parameter with DCT dimension
param = torch.nn.Parameter(torch.randn(100, 3))
param._pyro_dct_dim = -2  # time dimension is the second-to-last

optimizer = DCTAdam([param], lr=0.01, clip_norm=10.0)

for step in range(100):
    loss = param.pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# With subsample-aware updates for SVI
param2 = torch.nn.Parameter(torch.randn(1000, 5))
param2._pyro_dct_dim = -2
param2._pyro_subsample = {-2: torch.tensor([0, 10, 20])}

optimizer2 = DCTAdam([param2], lr=0.01, subsample_aware=True)
