Implementation:Pyro ppl Pyro LDA

From Leeroopedia


Property | Value
Implementation Type | Pattern Doc
Source File | examples/lda.py
Module | examples
Pyro Features | TraceEnum_ELBO, JitTraceEnum_ELBO, pyro.plate, config_enumerate, ClippedAdam, amortized inference with a neural-network guide, discrete variable marginalization
Papers | Srivastava & Sutton (2017), "Autoencoding Variational Inference for Topic Models"; Jankowiak & Obermeyer (2018), "Pathwise Derivatives Beyond the Reparameterization Trick"

Overview

This file implements amortized Latent Dirichlet Allocation (LDA) using Pyro's enumeration-based inference. The model treats documents as vectors of categorical word ids and uses TraceEnum_ELBO to exactly marginalize out the discrete word-topic assignments via parallel enumeration.
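For a single word position, summing out word_topics turns the two-step draw (topic, then word) into a mixture over topics. The symbols here are introduced for illustration and do not appear in the source:

- θ_d is doc_topics for document d.
- β_k is row k of topic_words.
- z_dn is word_topics.
- K is num_topics and N is num_words_per_doc.

With this notation, the quantity the enumerated ELBO works with is, roughly:

```latex
p(w_{dn} \mid \theta_d, \beta)
  = \sum_{k=1}^{K} p(z_{dn} = k \mid \theta_d)\, p(w_{dn} \mid z_{dn} = k, \beta)
  = \sum_{k=1}^{K} \theta_{dk}\, \beta_{k, w_{dn}}

\log p(w_{d1}, \dots, w_{dN} \mid \theta_d, \beta)
  = \sum_{n=1}^{N} \log \sum_{k=1}^{K} \theta_{dk}\, \beta_{k, w_{dn}}
```

Parallel enumeration evaluates all K values of z_dn at once in an extra tensor dimension and reduces over it with a log-sum-exp. This is why the guide needs no distribution for word_topics.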

Key design patterns demonstrated:

  • Discrete variable marginalization: The word_topics variable is annotated with infer={"enumerate": "parallel"} and excluded from the guide, allowing Pyro to sum it out exactly.
  • Amortized inference: An MLP-based guide predicts document-topic distributions from word count histograms, enabling generalization to new documents.
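  • Conjugate guide for globals: The guide gives topic_weights and topic_words the same families as their priors (Gamma and Dirichlet, respectively), with learnable parameters.
  • Reparameterized Gamma/Dirichlet: PyTorch's Gamma and Dirichlet distributions support reparameterized sampling (rsample). Gradients therefore flow pathwise through the guide's samples (cf. Jankowiak & Obermeyer, 2018), which gives lower-variance gradient estimates than score-function estimators.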

Code Reference

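import torch
import pyro
import pyro.distributions as dist


def model(data=None, args=None, batch_size=None):
    # Global variables: one Gamma weight and one word distribution per topic.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights",
            dist.Gamma(1.0 / args.num_topics, 1.0))
        topic_words = pyro.sample("topic_words",
            dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # Local variables. data has shape (num_words_per_doc, num_docs), so documents
    # sit on the rightmost dim; batch_size enables minibatch subsampling of documents.
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        if data is not None:
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            # Enumerated in parallel and summed out by TraceEnum_ELBO,
            # so the guide does not need a site for word_topics.
            word_topics = pyro.sample("word_topics",
                dist.Categorical(doc_topics),
                infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words",
                dist.Categorical(topic_words[word_topics]), obs=data)
    return topic_weights, topic_words, data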

I/O Contract

Parameter | Type | Description
-t / --num-topics | int | Number of topics (default: 8)
-w / --num-words | int | Vocabulary size (default: 1024)
-d / --num-docs | int | Number of documents (default: 1000)
-wd / --num-words-per-doc | int | Words per document (default: 64)
-n / --num-steps | int | Number of SVI steps (default: 1000)
-b / --batch-size | int | Document minibatch size (default: 32)
--jit | flag | Use the JIT-compiled ELBO (JitTraceEnum_ELBO) instead of TraceEnum_ELBO
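The table can be read as an argparse specification. The following is a minimal reconstruction written from the table alone, not copied from lda.py. The build_parser name and the help strings are hypothetical, and any other options the script may define are not shown.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Flags and defaults transcribed from the I/O table; help strings are paraphrases.
    parser = argparse.ArgumentParser(description="Amortized LDA (CLI sketch)")
    parser.add_argument("-t", "--num-topics", type=int, default=8, help="number of topics")
    parser.add_argument("-w", "--num-words", type=int, default=1024, help="vocabulary size")
    parser.add_argument("-d", "--num-docs", type=int, default=1000, help="number of documents")
    parser.add_argument("-wd", "--num-words-per-doc", type=int, default=64,
                        help="words per document")
    parser.add_argument("-n", "--num-steps", type=int, default=1000, help="SVI steps")
    parser.add_argument("-b", "--batch-size", type=int, default=32,
                        help="document minibatch size")
    parser.add_argument("--jit", action="store_true", help="use the JIT-compiled ELBO")
    return parser


# Parses the same flags as the JIT usage example.
args = build_parser().parse_args(["--jit", "-t", "16", "-n", "2000"])
print(args.num_topics, args.num_steps, args.jit)  # 16 2000 True
```

Because "-wd" is registered as its own option string, argparse matches it exactly and does not confuse it with "-w".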

Output:

  • Training ELBO loss per step
  • Learned topic weights and word distributions

Usage Examples

# Train LDA with 8 topics on synthetic data (these flag values equal the defaults)
python lda.py -t 8 -w 1024 -d 1000 -n 1000 -b 32

# Train with the JIT-compiled ELBO, 16 topics and 2000 SVI steps
python lda.py --jit -t 16 -n 2000
