Implementation:Pyro ppl Pyro LDA
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/lda.py |
| Module | examples |
| Pyro Features | TraceEnum_ELBO, JitTraceEnum_ELBO, pyro.plate, config_enumerate, ClippedAdam, amortized inference with neural network guide, discrete variable marginalization |
| Papers | Srivastava & Sutton (2017), "Autoencoding Variational Inference for Topic Models"; Jankowiak & Obermeyer (2018), "Pathwise Derivatives Beyond the Reparameterization Trick" |
Overview
This file implements amortized Latent Dirichlet Allocation (LDA) using Pyro's enumeration-based inference. The model treats documents as vectors of categorical word ids and uses TraceEnum_ELBO to exactly marginalize out the discrete word-topic assignments via parallel enumeration.
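Writing the marginalization out shows what TraceEnum_ELBO computes for the words. The symbols here are introduced for illustration: $\theta_d$ is `doc_topics` for document $d$, $\beta_k$ is `topic_words` for topic $k$, $z_{dn}$ is `word_topics` and $w_{dn}$ is the observed `doc_words` at position $n$, with $K$ = `args.num_topics`.

$$
\begin{aligned}
p(w_{dn} \mid \theta_d, \beta) &= \sum_{k=1}^{K} p(z_{dn} = k \mid \theta_d)\, p(w_{dn} \mid z_{dn} = k, \beta) = \sum_{k=1}^{K} \theta_{dk}\, \beta_{k, w_{dn}} \\
\log p(w \mid \theta, \beta) &= \sum_{d} \sum_{n} \log \sum_{k=1}^{K} \exp\left(\log \theta_{dk} + \log \beta_{k, w_{dn}}\right)
\end{aligned}
$$

Each word needs only $K$ terms, and because $z_{dn}$ is summed rather than sampled, the word likelihood carries no Monte Carlo noise from the discrete assignments. Parallel enumeration evaluates all $K$ terms at once along an extra tensor dimension.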
Key design patterns demonstrated:
- Discrete variable marginalization: The `word_topics` variable is annotated with `infer={"enumerate": "parallel"}` and excluded from the guide, allowing Pyro to sum it out exactly.
- Amortized inference: An MLP-based guide predicts document-topic distributions from word count histograms, enabling generalization to new documents.
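- Conjugate guide for globals: Topic weights and topic-word distributions get Gamma and Dirichlet variational posteriors, matching the families of their priors, with directly learned (non-amortized) parameters.
- Reparameterized Gamma/Dirichlet: PyTorch's reparameterized Gamma and Dirichlet samplers supply pathwise gradients (Jankowiak & Obermeyer, 2018), giving low-variance gradient estimates without the Laplace approximation used by Srivastava & Sutton (2017).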
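The amortized-inference pattern can be illustrated with a minimal NumPy sketch (not the example's PyTorch guide) of the two steps applied to each minibatch: converting word ids into per-document count histograms, then mapping those counts through an MLP to topic proportions. The helper names `word_histograms`, `init_mlp` and `predict_doc_topics`, the hidden sizes and the sigmoid/softmax activations are illustrative assumptions.

```python
import numpy as np


def word_histograms(data, num_words):
    """Convert word ids of shape (num_words_per_doc, batch) into
    per-document count vectors of shape (batch, num_words)."""
    counts = np.zeros((data.shape[1], num_words))
    # The column of each token identifies which document it belongs to.
    doc_index = np.broadcast_to(np.arange(data.shape[1]), data.shape)
    # Unbuffered add so repeated (doc, word) pairs are all counted.
    np.add.at(counts, (doc_index, data), 1.0)
    return counts


def init_mlp(sizes, rng):
    # Hypothetical layer sizes; small random weights and biases per layer.
    return [(rng.normal(0.0, 0.001, (m, n)), rng.normal(0.0, 0.001, n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def predict_doc_topics(counts, layers):
    h = counts
    for i, (W, b) in enumerate(layers):
        h = h @ W + b
        if i < len(layers) - 1:
            h = 1.0 / (1.0 + np.exp(-h))  # sigmoid on hidden layers
    # Softmax maps the output layer onto the topic simplex.
    h = h - h.max(axis=1, keepdims=True)
    e = np.exp(h)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
data = rng.integers(0, 1024, size=(64, 32))  # (num_words_per_doc, batch)
counts = word_histograms(data, 1024)
layers = init_mlp([1024, 100, 100, 8], rng)
doc_topics = predict_doc_topics(counts, layers)
print(doc_topics.shape)  # (32, 8)
```

Because one network is shared by all documents, the number of local variational parameters does not grow with `num_docs`, and topic proportions for an unseen document cost a single forward pass.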
Code Reference
```python
import torch
import pyro
import pyro.distributions as dist


def model(data=None, args=None, batch_size=None):
    # Global variables: a Gamma weight and a word distribution per topic.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights",
                                    dist.Gamma(1.0 / args.num_topics, 1.0))
        topic_words = pyro.sample("topic_words",
                                  dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Local variables: subsample batch_size documents per step.
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        if data is not None:
            # data has shape (num_words_per_doc, num_docs); keep the minibatch columns.
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # Summed out by TraceEnum_ELBO, so the guide does not sample it.
            word_topics = pyro.sample("word_topics",
                                      dist.Categorical(doc_topics),
                                      infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                               dist.Categorical(topic_words[word_topics]), obs=data)

    return topic_weights, topic_words, data
```
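To see what parallel enumeration of `word_topics` amounts to, the NumPy sketch below (an illustration, not Pyro's implementation) places every topic value in an extra leftmost tensor dimension, indexes `topic_words` with that enumerated tensor the way `topic_words[word_topics]` is evaluated in the model, and sums the dimension out with log-sum-exp. The function names are hypothetical, and the shapes assume `data` is laid out `(num_words_per_doc, num_docs)` as in the model.

```python
import numpy as np


def log_marginal_enumerated(data, doc_topics, topic_words):
    """Log-likelihood of doc_words with word_topics summed out.

    data:        (num_words_per_doc, num_docs) integer word ids
    doc_topics:  (num_docs, num_topics), rows on the simplex
    topic_words: (num_topics, num_words), rows on the simplex
    """
    num_topics = topic_words.shape[0]
    # Enumerated topic values sit left of the "words" and "documents" dims.
    k = np.arange(num_topics).reshape(num_topics, 1, 1)
    log_prior = np.log(doc_topics.T)[:, None, :]   # (K, 1, D)
    log_lik = np.log(topic_words[k, data])         # (K, N, D)
    joint = log_prior + log_lik
    # Stable log-sum-exp over the enumeration dimension.
    m = joint.max(axis=0)
    return float(np.sum(m + np.log(np.exp(joint - m).sum(axis=0))))


def log_marginal_direct(data, doc_topics, topic_words):
    # Closed form: each word is Categorical(doc_topics @ topic_words).
    word_probs = doc_topics @ topic_words          # (D, V)
    docs = np.arange(data.shape[1])
    return float(np.log(word_probs[docs, data]).sum())


rng = np.random.default_rng(0)
K, V, D, N = 4, 50, 6, 10
doc_topics = rng.dirichlet(np.ones(K), size=D)
topic_words = rng.dirichlet(np.ones(V), size=K)
data = rng.integers(0, V, size=(N, D))
print(np.isclose(log_marginal_enumerated(data, doc_topics, topic_words),
                 log_marginal_direct(data, doc_topics, topic_words)))  # True
```

Since the enumeration dimension must sit to the left of both plate dimensions, TraceEnum_ELBO needs `max_plate_nesting` of at least 2 for this model.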
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| `-t` / `--num-topics` | int | Number of topics (default: 8) |
| `-w` / `--num-words` | int | Vocabulary size (default: 1024) |
| `-d` / `--num-docs` | int | Number of documents (default: 1000) |
| `-wd` / `--num-words-per-doc` | int | Words per document (default: 64) |
| `-n` / `--num-steps` | int | Number of SVI steps (default: 1000) |
| `-b` / `--batch-size` | int | Documents per minibatch (default: 32) |
| `--jit` | flag | Use the JIT-compiled ELBO (`JitTraceEnum_ELBO`) |
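These flags become the `args` attributes the model reads (`args.num_topics`, `args.num_words`, `args.num_docs`, `args.num_words_per_doc`). A minimal `argparse` sketch consistent with the table follows; the `make_parser` helper and description string are assumptions, and the real script may register further options.

```python
import argparse


def make_parser():
    # Flags and defaults follow the I/O table above (hypothetical helper).
    parser = argparse.ArgumentParser(description="Amortized LDA (sketch)")
    parser.add_argument("-t", "--num-topics", type=int, default=8)
    parser.add_argument("-w", "--num-words", type=int, default=1024)
    parser.add_argument("-d", "--num-docs", type=int, default=1000)
    parser.add_argument("-wd", "--num-words-per-doc", type=int, default=64)
    parser.add_argument("-n", "--num-steps", type=int, default=1000)
    parser.add_argument("-b", "--batch-size", type=int, default=32)
    parser.add_argument("--jit", action="store_true")
    return parser


args = make_parser().parse_args(["--jit", "-t", "16", "-n", "2000"])
print(args.num_topics, args.num_steps, args.batch_size, args.jit)  # 16 2000 32 True
```

argparse converts `--num-words-per-doc` into the attribute `num_words_per_doc`, which is the name the model uses.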
Output:
- Training ELBO loss per step
- Learned topic weights and word distributions
Usage Examples
```shell
# Train LDA with 8 topics on synthetic data
python lda.py -t 8 -w 1024 -d 1000 -n 1000 -b 32
# With JIT compilation
python lda.py --jit -t 16 -n 2000
```
Related Pages
- Pyro_ppl_Pyro_Funsor_HMM - Another example using parallel enumeration
- Pyro_ppl_Pyro_SparseGammaDEF - Deep generative model for images