Implementation:Pyro ppl Pyro SparseGammaDEF
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/sparse_gamma_def.py |
| Module | examples |
| Pyro Features | pyro.plate, pyro.sample, pyro.param, TraceMeanField_ELBO, AutoDiagonalNormal, EasyGuide, custom guide, Gamma/Poisson distributions, three guide strategies |
| Paper | Ranganath et al. (2015), "Deep Exponential Families" |
| Dataset | Olivetti Faces (64x64 grayscale images) |
Overview
This file implements the Sparse Gamma Deep Exponential Family (DEF) model, a deep generative model with three layers of Gamma-distributed latent variables connected by Gamma-distributed weight matrices. The model uses a Poisson likelihood for image data.
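The Poisson likelihood mentioned above scores each pixel as a count given a positive rate. As a minimal standard-library sketch (the function names `poisson_log_prob` and `image_log_likelihood` and the toy numbers are illustrative assumptions, not taken from the example file), the per-image log-likelihood is the sum of per-pixel Poisson log-probabilities:

```python
import math

def poisson_log_prob(count, rate):
    """Log-probability of a non-negative integer count under Poisson(rate)."""
    return count * math.log(rate) - rate - math.lgamma(count + 1)

def image_log_likelihood(pixels, rates):
    """Sum per-pixel log-probabilities, treating pixels as independent given the rates."""
    return sum(poisson_log_prob(x, r) for x, r in zip(pixels, rates))

# Toy 4-pixel "image" and hypothetical rates standing in for z_bottom * w_bottom.
pixels = [0, 3, 7, 2]
rates = [0.5, 2.5, 6.0, 2.0]
log_lik = image_log_likelihood(pixels, rates)
```

Summing over pixels corresponds to treating the pixel dimension as a single event, which is what `.to_event(1)` does for the observation site in the code reference.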
Three guide strategies are demonstrated and compared:
- Custom guide: A hand-designed mean-field Gamma variational family with explicit parameter clipping to avoid numerical issues.
- Auto guide: AutoDiagonalNormal, which automatically constructs a diagonal Normal variational distribution (in unconstrained space).
- Easy guide: EasyGuide, which groups latent variables and uses Normal variational families, providing a middle ground between custom and auto guides. This is the best-performing option.
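The auto and easy guides put Normal distributions on latents that must stay positive. This works because the Normal lives in unconstrained space and is mapped back through a positivity transform. The sketch below illustrates the idea with the standard library only, assuming an exponential transform; the helper `sample_positive` is hypothetical and is not a Pyro function.

```python
import math
import random

def sample_positive(loc, scale, rng):
    """Draw u ~ Normal(loc, scale) in unconstrained space and map it to z = exp(u) > 0.

    Returns (z, log_q_z), where log_q_z is the log-density of z and
    includes the change-of-variables term -log|dz/du| = -u.
    """
    u = rng.gauss(loc, scale)
    z = math.exp(u)
    log_q_u = (
        -0.5 * ((u - loc) / scale) ** 2
        - math.log(scale)
        - 0.5 * math.log(2 * math.pi)
    )
    log_q_z = log_q_u - u  # Jacobian correction for z = exp(u)
    return z, log_q_z

rng = random.Random(0)
samples = [sample_positive(loc=-1.0, scale=0.5, rng=rng) for _ in range(5)]
```

Every draw is strictly positive by construction, so no clipping is needed on the sampled value itself; the custom Gamma guide, by contrast, has to keep its own parameters in a safe range.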
The model architecture:
- Top layer: z_top (100 units) with Gamma prior
- Mid layer: z_mid (40 units) with rate modulated by z_top * w_top
- Bottom layer: z_bottom (15 units) with rate modulated by z_mid * w_mid
- Observation: Poisson(z_bottom * w_bottom) for 64x64 = 4096 pixel values
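The layer structure above can be sketched as ancestral sampling in plain Python. This is a toy illustration under stated assumptions: the widths are shrunk to 4/3/2 units and 5 pixels, all hyperparameters are set to 1.0, and the helpers `matvec`, `gamma_draw`, and `sample_rates` are hypothetical names that do not appear in the example file.

```python
import random

def matvec(z, w):
    """Compute z @ w for a vector z (length n) and a matrix w (n rows, m columns)."""
    return [sum(z[i] * w[i][j] for i in range(len(z))) for j in range(len(w[0]))]

def gamma_draw(rng, concentration, rate):
    """Gamma draw in the (concentration, rate) parameterisation.

    random.gammavariate expects a scale, so 1 / rate is passed.
    The result is floored at 1e-12 so later divisions stay finite.
    """
    return max(rng.gammavariate(concentration, 1.0 / rate), 1e-12)

def sample_rates(rng, top=4, mid=3, bottom=2, pixels=5,
                 alpha_z=1.0, beta_z=1.0, alpha_w=1.0, beta_w=1.0):
    """Sample weights and latents top-down and return the Poisson rates for one image."""
    def weights(rows, cols):
        return [[gamma_draw(rng, alpha_w, beta_w) for _ in range(cols)] for _ in range(rows)]

    w_top, w_mid, w_bottom = weights(top, mid), weights(mid, bottom), weights(bottom, pixels)

    z_top = [gamma_draw(rng, alpha_z, beta_z) for _ in range(top)]
    mean_mid = matvec(z_top, w_top)
    # Dividing the rate by mean_mid scales the mean of z_mid by mean_mid.
    z_mid = [gamma_draw(rng, alpha_z, beta_z / m) for m in mean_mid]
    mean_bottom = matvec(z_mid, w_mid)
    z_bottom = [gamma_draw(rng, alpha_z, beta_z / m) for m in mean_bottom]
    mean_obs = matvec(z_bottom, w_bottom)  # Poisson rate per pixel
    return z_top, z_mid, z_bottom, mean_obs

z_top, z_mid, z_bottom, mean_obs = sample_rates(random.Random(0))
```

Each layer narrows (4 to 3 to 2 here, 100 to 40 to 15 in the model), and only the final product `mean_obs` has one entry per pixel.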
Code Reference
import torch
import pyro
from pyro.distributions import Gamma, Poisson

class SparseGammaDEF:
    def model(self, x):
        x_size = x.size(0)
        # Global weights: each weight matrix is sampled as a flat vector under its own plate
        with pyro.plate("w_top_plate", self.top_width * self.mid_width):
            w_top = pyro.sample("w_top", Gamma(self.alpha_w, self.beta_w))
        # ... similar for w_mid, w_bottom
        # Local latents: one set of z's per image in the batch
        with pyro.plate("data", x_size):
            z_top = pyro.sample("z_top",
                Gamma(self.alpha_z, self.beta_z).expand([self.top_width]).to_event(1))
            mean_mid = torch.matmul(z_top, w_top.reshape(...))
            # Gamma takes (concentration, rate); dividing the rate by mean_mid scales the mean of z_mid
            z_mid = pyro.sample("z_mid",
                Gamma(self.alpha_z, self.beta_z / mean_mid).to_event(1))
            mean_bottom = torch.matmul(z_mid, w_mid.reshape(...))
            z_bottom = pyro.sample("z_bottom",
                Gamma(self.alpha_z, self.beta_z / mean_bottom).to_event(1))
            mean_obs = torch.matmul(z_bottom, w_bottom.reshape(...))
            # Poisson likelihood over all pixels of each image
            pyro.sample("obs", Poisson(mean_obs).to_event(1), obs=x)
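from pyro.contrib.easyguide import EasyGuide
from pyro.distributions import Normal

class MyEasyGuide(EasyGuide):
    def guide(self, x):
        # global_mean, global_scale, local_mean, and local_scale are variational
        # parameters; their definitions are omitted from this excerpt.
        # Group all weight sites into one joint latent
        global_group = self.group(match="w_.*")
        global_group.sample("ws", Normal(global_mean, global_scale).to_event(1))
        # Group all per-image latent sites into one joint latent
        local_group = self.group(match="z_.*")
        with self.plate("data", x.size(0)):
            local_group.sample("zs", Normal(local_mean, local_scale).to_event(1))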
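The two `match` patterns in the guide split the latent sites by name prefix. The sketch below mimics that selection with the standard `re` module; it assumes the patterns are applied to site names with `re.match`, and the helper `group_sites` is a hypothetical stand-in rather than part of EasyGuide.

```python
import re

# Sample site names used by the model on this page; "obs" is observed, not latent.
latent_sites = ["w_top", "w_mid", "w_bottom", "z_top", "z_mid", "z_bottom"]

def group_sites(pattern, names):
    """Return the names whose start matches the regular expression, in order."""
    return [name for name in names if re.match(pattern, name)]

global_names = group_sites("w_.*", latent_sites)  # weight sites
local_names = group_sites("z_.*", latent_sites)   # per-image latent sites

# Every latent site should land in exactly one of the two groups.
covered = sorted(global_names + local_names) == sorted(latent_sites)
```

Because the prefixes `w_` and `z_` do not overlap, the two groups partition the latent sites, which is what lets the guide use one Normal for the weights and one per-image Normal for the local latents.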
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| -n / --num-epochs | int | Training epochs (default: 1500) |
| -ef / --eval-frequency | int | Evaluation interval in epochs (default: 25) |
| -ep / --eval-particles | int | Particles for evaluation (default: 20) |
| --guide | str | Guide type: "custom", "auto", or "easy" |
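The flags in the table map naturally onto `argparse`. The following is a sketch of such a parser, not a copy of the one in the example file: the numeric defaults come from the table, while the `--guide` default of "custom", the use of `choices`, and the function name `build_parser` are assumptions.

```python
import argparse

def build_parser():
    """Build a command-line parser matching the flags in the I/O contract table."""
    parser = argparse.ArgumentParser(description="Sparse Gamma DEF options (sketch)")
    parser.add_argument("-n", "--num-epochs", type=int, default=1500,
                        help="number of training epochs")
    parser.add_argument("-ef", "--eval-frequency", type=int, default=25,
                        help="evaluate every this many epochs")
    parser.add_argument("-ep", "--eval-particles", type=int, default=20,
                        help="number of particles used at evaluation time")
    # Default and choices are assumptions; the table only lists the allowed values.
    parser.add_argument("--guide", type=str, default="custom",
                        choices=["custom", "auto", "easy"],
                        help="which guide strategy to use")
    return parser

# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args(["--guide", "easy", "-n", "10"])
```

Restricting `--guide` with `choices` makes a mistyped guide name fail at parse time instead of partway into training.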
Output:
- Training ELBO at each evaluation point
Model dimensions:
- Weight matrices: top (100x40), mid (40x15), bottom (15x4096)
- Latent layers: z_top (100), z_mid (40), z_bottom (15)
- Image size: 64x64 = 4096 pixels
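The dimensions above imply how many latent values each part of the model carries. A short arithmetic sketch follows; the variable names are illustrative, and `check_chain` is a hypothetical helper.

```python
top_width, mid_width, bottom_width = 100, 40, 15
image_size = 64 * 64

# Shape of each weight matrix as (rows, columns), listed top to bottom.
weight_shapes = {
    "w_top": (top_width, mid_width),
    "w_mid": (mid_width, bottom_width),
    "w_bottom": (bottom_width, image_size),
}

weight_counts = {name: rows * cols for name, (rows, cols) in weight_shapes.items()}
total_weights = sum(weight_counts.values())               # shared across all images
latents_per_image = top_width + mid_width + bottom_width  # sampled once per image

def check_chain(shapes):
    """Check that each matrix's column count equals the next matrix's row count."""
    ordered = list(shapes.values())
    return all(a[1] == b[0] for a, b in zip(ordered, ordered[1:]))

chain_ok = check_chain(weight_shapes)
```

The bottom weight matrix dominates: 61,440 of the 66,040 weights connect the 15 bottom units to the 4,096 pixels, while each image adds only 155 local latent values.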
Usage Examples
# Train with custom guide (original paper approach)
# python sparse_gamma_def.py --guide custom -n 1500
# Train with EasyGuide (best performance)
# python sparse_gamma_def.py --guide easy -n 1500
# Train with auto guide
# python sparse_gamma_def.py --guide auto -n 1500
Related Pages
- Pyro_ppl_Pyro_LDA - Another deep generative model using plates and enumeration
- Pyro_ppl_Pyro_NeuTra - Uses AutoDiagonalNormal guide