
Implementation:Pyro ppl Pyro SparseGammaDEF



Implementation Type: Pattern Doc
Source File: examples/sparse_gamma_def.py
Module: examples
Pyro Features: pyro.plate, pyro.sample, pyro.param, TraceMeanField_ELBO, AutoDiagonalNormal, EasyGuide, custom guide, Gamma/Poisson distributions, three guide strategies
Paper: Ranganath et al. (2015), "Deep Exponential Families"
Dataset: Olivetti Faces (64x64 grayscale images)

Overview

This file implements the Sparse Gamma Deep Exponential Family (DEF) model, a deep generative model with three layers of Gamma-distributed latent variables connected by Gamma-distributed weight matrices. The model uses a Poisson likelihood for image data.
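The generative process can be written out as follows. This is a sketch read off the model code in the Code Reference section, not a formula quoted from the source. The index names are notation introduced here: $n$ for images, $k$, $j$, $l$ for top, mid and bottom units, and $d$ for pixels. Gamma uses the shape/rate parameterization of `pyro.distributions.Gamma`. A Gamma(a, b) variable has mean a/b. Dividing the rate $\beta_z$ by the upper layer's activation therefore makes each lower layer's mean proportional to that activation, as the last line shows.

```latex
\begin{aligned}
W^{\mathrm{top}}_{kj},\; W^{\mathrm{mid}}_{jl},\; W^{\mathrm{bot}}_{ld} &\sim \mathrm{Gamma}(\alpha_w, \beta_w) \\
z^{\mathrm{top}}_{nk} &\sim \mathrm{Gamma}(\alpha_z, \beta_z) \\
z^{\mathrm{mid}}_{nj} \mid z^{\mathrm{top}}_n &\sim \mathrm{Gamma}\!\left(\alpha_z,\; \beta_z \,/\, (z^{\mathrm{top}}_n W^{\mathrm{top}})_j\right) \\
z^{\mathrm{bot}}_{nl} \mid z^{\mathrm{mid}}_n &\sim \mathrm{Gamma}\!\left(\alpha_z,\; \beta_z \,/\, (z^{\mathrm{mid}}_n W^{\mathrm{mid}})_l\right) \\
x_{nd} \mid z^{\mathrm{bot}}_n &\sim \mathrm{Poisson}\!\left((z^{\mathrm{bot}}_n W^{\mathrm{bot}})_d\right) \\[4pt]
\mathbb{E}\!\left[z^{\mathrm{mid}}_{nj} \mid z^{\mathrm{top}}_n\right] &= \frac{\alpha_z}{\beta_z}\,(z^{\mathrm{top}}_n W^{\mathrm{top}})_j
\end{aligned}
```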

Three guide strategies are demonstrated and compared:

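  • Custom guide: a hand-designed mean-field Gamma variational family, with explicit parameter clipping to avoid numerical issues.
  • Auto guide: AutoDiagonalNormal, which automatically constructs a diagonal Normal variational distribution over the latent variables in unconstrained space.
  • Easy guide: an EasyGuide subclass that groups all weight sites (w_*) into one global latent and all per-image sites (z_*) into one local latent, each with a mean-field Normal variational family. It sits between the other two, needing less hand-written code than the custom guide while giving more control than the auto guide. In this example it is the best-performing of the three.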

The model architecture:

  • Top layer: z_top (100 units) with a Gamma(alpha_z, beta_z) prior
  • Mid layer: z_mid (40 units) with rate beta_z / (z_top @ w_top), so its mean scales with z_top @ w_top
  • Bottom layer: z_bottom (15 units) with rate beta_z / (z_mid @ w_mid), so its mean scales with z_mid @ w_mid
  • Observation: Poisson(z_bottom @ w_bottom) over the 64x64 = 4096 pixel values of each image
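The layer-by-layer structure can be made concrete with a minimal ancestral-sampling sketch in plain Python that uses the standard library only. The widths are shrunk to hypothetical toy sizes. The hyperparameter values `ALPHA_Z`, `BETA_Z`, `ALPHA_W` and `BETA_W` are assumptions for illustration, not values taken from the source. `random.gammavariate` is parameterized by shape and scale, so the rate is inverted before sampling.

```python
import math
import random

# Hypothetical toy sizes; the real model uses 100 / 40 / 15 units and 4096 pixels.
TOP, MID, BOTTOM, PIXELS = 4, 3, 2, 6
# Assumed hyperparameters, for illustration only.
ALPHA_Z, BETA_Z, ALPHA_W, BETA_W = 0.1, 0.1, 0.1, 0.3


def gamma(shape, rate, rng):
    # random.gammavariate takes (shape, scale), so pass scale = 1 / rate.
    return rng.gammavariate(shape, 1.0 / rate)


def poisson(rate, rng):
    # Knuth's multiplication method; adequate for the small rates used here.
    threshold, k, p = math.exp(-rate), 0, rng.random()
    while p > threshold:
        k += 1
        p *= rng.random()
    return k


def weights(rows, cols, rng):
    # Global weight matrix with i.i.d. Gamma(ALPHA_W, BETA_W) entries.
    return [[gamma(ALPHA_W, BETA_W, rng) for _ in range(cols)] for _ in range(rows)]


def matvec(z, w):
    # Row vector z times matrix w, i.e. z @ w.
    return [sum(z[i] * w[i][j] for i in range(len(z))) for j in range(len(w[0]))]


def next_layer(upper, w, rng):
    # Rate BETA_Z / mean gives each unit mean (ALPHA_Z / BETA_Z) * mean;
    # max(..., 1e-300) guards against a mean that underflowed to exactly 0.
    return [gamma(ALPHA_Z, BETA_Z / max(m, 1e-300), rng) for m in matvec(upper, w)]


def sample_image(w_top, w_mid, w_bottom, rng):
    z_top = [gamma(ALPHA_Z, BETA_Z, rng) for _ in range(TOP)]
    z_mid = next_layer(z_top, w_top, rng)
    z_bottom = next_layer(z_mid, w_mid, rng)
    return [poisson(rate, rng) for rate in matvec(z_bottom, w_bottom)]


rng = random.Random(0)
# Weights are drawn once and shared by every image; the z layers are redrawn per image.
w_top = weights(TOP, MID, rng)
w_mid = weights(MID, BOTTOM, rng)
w_bottom = weights(BOTTOM, PIXELS, rng)
images = [sample_image(w_top, w_mid, w_bottom, rng) for _ in range(5)]
```

The split between shared weights and per-image z layers mirrors the two kinds of plates in the model. The weight plates sit outside the `data` plate, and the z sites sit inside it.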

Code Reference

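import torch
import pyro
from pyro.distributions import Gamma, Normal, Poisson
from pyro.contrib.easyguide import EasyGuide

# Excerpt: __init__ (layer widths, hyperparameters), the reshape arguments
# and the variational parameter definitions are omitted.
class SparseGammaDEF:
    def model(self, x):
        x_size = x.size(0)
        # Global weights: one Gamma draw per matrix entry, flattened into a plate
        with pyro.plate("w_top_plate", self.top_width * self.mid_width):
            w_top = pyro.sample("w_top", Gamma(self.alpha_w, self.beta_w))
        # ... similar for w_mid, w_bottom

        # Local latents: conditionally independent across images
        with pyro.plate("data", x_size):
            z_top = pyro.sample("z_top",
                Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1))
            # Rate beta_z / mean_mid gives z_mid a mean proportional to mean_mid
            mean_mid = torch.matmul(z_top, w_top.reshape(...))
            z_mid = pyro.sample("z_mid",
                Gamma(self.alpha_z, self.beta_z / mean_mid).to_event(1))
            mean_bottom = torch.matmul(z_mid, w_mid.reshape(...))
            z_bottom = pyro.sample("z_bottom",
                Gamma(self.alpha_z, self.beta_z / mean_bottom).to_event(1))
            mean_obs = torch.matmul(z_bottom, w_bottom.reshape(...))
            # Poisson likelihood over all pixels of each image
            pyro.sample("obs", Poisson(mean_obs).to_event(1), obs=x)

class MyEasyGuide(EasyGuide):
    def guide(self, x):
        # Join all weight sites (w_top, w_mid, w_bottom) into one global latent
        global_group = self.group(match="w_.*")
        # global_mean / global_scale: variational parameters (definitions omitted)
        global_group.sample("ws", Normal(global_mean, global_scale).to_event(1))
        # Join all per-image sites (z_top, z_mid, z_bottom) into one local latent
        local_group = self.group(match="z_.*")
        with self.plate("data", x.size(0)):
            # local_mean / local_scale: variational parameters (definitions omitted)
            local_group.sample("zs", Normal(local_mean, local_scale).to_event(1))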
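The `reshape(...)` arguments are elided in the excerpt above. The sketch below assumes each flattened weight plate is laid out row-major as a (rows, cols) matrix and works through the shape bookkeeping for the top layer. `as_matrix` and the sizes are hypothetical and not code from the source. The second case adds a leading particle dimension P, as would happen if ELBO particles were vectorized. There, `reshape(-1, rows, cols)` plus batched `torch.matmul` handle the extra dimension without a Python loop.

```python
import torch

top_width, mid_width, n_images, n_particles = 5, 3, 4, 2  # hypothetical sizes


def as_matrix(w_flat):
    # Hypothetical helper: a single flattened draw -> (top, mid) matrix;
    # a leading particle dimension -> a batch of (top, mid) matrices.
    if w_flat.dim() == 1:
        return w_flat.reshape(top_width, mid_width)
    return w_flat.reshape(-1, top_width, mid_width)


# Single draw: w_top is (top*mid,), z_top is (N, top).
w_top = torch.rand(top_width * mid_width)
z_top = torch.rand(n_images, top_width)
mean_mid = torch.matmul(z_top, as_matrix(w_top))  # (N, mid)

# Vectorized particles: both tensors gain a leading dimension P.
w_top_p = torch.rand(n_particles, top_width * mid_width)
z_top_p = torch.rand(n_particles, n_images, top_width)
mean_mid_p = torch.matmul(z_top_p, as_matrix(w_top_p))  # (P, N, mid)
```

Each particle slice of `mean_mid_p` equals the single-draw computation applied to that particle's tensors, which is what makes the batched form a drop-in replacement.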

I/O Contract

Parameters:

  • -n / --num-epochs (int): number of training epochs (default: 1500)
  • -ef / --eval-frequency (int): evaluation interval in epochs (default: 25)
  • -ep / --eval-particles (int): number of particles used for evaluation (default: 20)
  • --guide (str): guide type, one of "custom", "auto", or "easy"
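The flags above map naturally onto a standard `argparse` parser. This is a hypothetical reconstruction, not the script's own parser. The defaults are the ones listed in the table. The table gives no default for `--guide`, so this sketch leaves it as `None`.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of the documented command-line flags.
    parser = argparse.ArgumentParser(description="Sparse Gamma DEF (sketch)")
    parser.add_argument("-n", "--num-epochs", type=int, default=1500,
                        help="number of training epochs")
    parser.add_argument("-ef", "--eval-frequency", type=int, default=25,
                        help="evaluate the ELBO every this many epochs")
    parser.add_argument("-ep", "--eval-particles", type=int, default=20,
                        help="number of particles used for evaluation")
    parser.add_argument("--guide", type=str, choices=["custom", "auto", "easy"],
                        help="which variational guide to train")
    return parser


args = build_parser().parse_args(["--guide", "easy", "-n", "100"])
```

`argparse` derives the attribute names from the long options, so `--num-epochs` becomes `args.num_epochs`, and `choices` rejects any guide name other than the three listed.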

Output:

  • Training ELBO, reported every --eval-frequency epochs and estimated with --eval-particles particles

Model dimensions:

  • Weight matrices: top (100x40), mid (40x15), bottom (15x4096)
  • Latent layers: z_top (100), z_mid (40), z_bottom (15)
  • Image size: 64x64 = 4096 pixels
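These dimensions determine how many latent coordinates a guide must cover. The counts below are plain arithmetic on the sizes listed above. `diag_normal_param_count` is a hypothetical helper. It assumes a diagonal Normal guide such as AutoDiagonalNormal, which keeps one location and one scale per latent coordinate. The number of images is left as an argument because the dataset size used by the script is not stated here.

```python
TOP, MID, BOTTOM = 100, 40, 15
PIXELS = 64 * 64

weight_entries = {
    "w_top": TOP * MID,           # 4000
    "w_mid": MID * BOTTOM,        # 600
    "w_bottom": BOTTOM * PIXELS,  # 61440
}
global_latents = sum(weight_entries.values())  # 66040, shared by all images
local_latents_per_image = TOP + MID + BOTTOM   # 155 per image


def diag_normal_param_count(n_images):
    # One loc and one scale per unconstrained latent coordinate.
    return 2 * (global_latents + local_latents_per_image * n_images)


print(diag_normal_param_count(100))  # 163080
```

The bottom weight matrix accounts for 61440 of the 66040 global entries (about 93%), so the observation layer dominates the global variational parameters.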

Usage Examples

# Train with custom guide (original paper approach)
# python sparse_gamma_def.py --guide custom -n 1500

# Train with EasyGuide (best performance)
# python sparse_gamma_def.py --guide easy -n 1500

# Train with auto guide
# python sparse_gamma_def.py --guide auto -n 1500
