

Implementation:Pyro ppl Pyro CEVAE Synthetic

From Leeroopedia


| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/cevae/synthetic.py |
| Module | pyro.contrib.cevae |
| Pyro Features | pyro.contrib.cevae.CEVAE, causal inference, treatment effect estimation, variational autoencoder |
| Paper | Louizos et al. (2017), "Causal Effect Inference with Deep Latent-Variable Models" (NeurIPS 2017) |

Overview

This example demonstrates the Causal Effect Variational Autoencoder (CEVAE) for estimating individual and average treatment effects from observational data. It uses a synthetic data-generating process in which a binary latent confounder z influences both the treatment assignment t and the outcome y, and also generates the observed features x, which act as noisy proxies for z.

The CEVAE model learns to:

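  • Infer a latent representation of the hidden confounder from the observed data (in CEVAE this is a continuous latent_dim-dimensional vector rather than the binary z used to generate the data; at prediction time only x is needed)
  • Estimate treatment effects by predicting outcomes under both the treatment and control conditions
  • Compute Individual Treatment Effects (ITE) and the Average Treatment Effect (ATE)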
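For reference, the two target quantities can be written in the interventional notation used in the CEVAE paper. Under that reading, the per-point output of cevae.ite(x) estimates the first quantity and its mean over the test set estimates the second; the notation is added here for clarity and does not appear in the script itself.

```latex
\begin{aligned}
\mathrm{ITE}(x) &= \mathbb{E}\left[y \mid X = x, \mathrm{do}(t = 1)\right] - \mathbb{E}\left[y \mid X = x, \mathrm{do}(t = 0)\right] \\
\mathrm{ATE} &= \mathbb{E}_{x}\left[\mathrm{ITE}(x)\right]
\end{aligned}
```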

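The synthetic data-generating process is: z ~ Bernoulli(0.5); x ~ Normal(z, 5z + 3(1 - z)), where the second argument is the standard deviation and feature_dim independent features are drawn per data point; t ~ Bernoulli(0.75z + 0.25(1 - z)); y ~ Bernoulli(logits = 3(z + 2(2t - 2))).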
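Because the confounder is binary, the population-level effects implied by this process can be worked out by hand. Setting t = 1 gives logit 3z and t = 0 gives logit 3z - 12, so a unit's true ITE is sigmoid(3z) - sigmoid(3z - 12), and averaging over z ~ Bernoulli(0.5) gives the true ATE. The naive contrast E[y | t = 1] - E[y | t = 0] mixes the two strata in different proportions, because by Bayes' rule P(z = 1 | t = 1) = 0.75 while P(z = 1 | t = 0) = 0.25. The following is a minimal stdlib sketch of that calculation; the helper names are illustrative and not taken from the script.

```python
import math

P_Z1 = 0.5  # z ~ Bernoulli(0.5)


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def outcome_prob(z: int, t: int) -> float:
    # P(y = 1 | z, t) with logits = 3(z + 2(2t - 2))
    return sigmoid(3 * (z + 2 * (2 * t - 2)))


def p_t1_given_z(z: int) -> float:
    # t ~ Bernoulli(0.75z + 0.25(1 - z))
    return 0.75 * z + 0.25 * (1 - z)


def p_z1_given_t(t: int) -> float:
    # Bayes' rule: posterior probability of z = 1 after observing treatment t
    like1 = p_t1_given_z(1) if t == 1 else 1 - p_t1_given_z(1)
    like0 = p_t1_given_z(0) if t == 1 else 1 - p_t1_given_z(0)
    return like1 * P_Z1 / (like1 * P_Z1 + like0 * (1 - P_Z1))


def true_ate() -> float:
    # Average the per-stratum ITE over the prior on z
    return sum(
        (P_Z1 if z == 1 else 1 - P_Z1) * (outcome_prob(z, 1) - outcome_prob(z, 0))
        for z in (0, 1)
    )


def naive_ate() -> float:
    # Unadjusted contrast: each arm averages over its own posterior on z
    def mean_y(t: int) -> float:
        p1 = p_z1_given_t(t)
        return p1 * outcome_prob(1, t) + (1 - p1) * outcome_prob(0, t)

    return mean_y(1) - mean_y(0)


print(f"true ATE  = {true_ate():.3f}")
print(f"naive ATE = {naive_ate():.3f}")
```

Evaluated exactly, the true ATE is about 0.726 and the naive contrast about 0.839, so ignoring the confounder overstates the effect. The values the script reports are sample estimates and will fluctuate around these population figures.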

Code Reference

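import torch
import pyro.distributions as dist
from pyro.contrib.cevae import CEVAE

def generate_data(args):
    # Binary latent confounder, one value per data point
    z = dist.Bernoulli(0.5).sample([args.num_data])
    # feature_dim noisy proxies of z (scale is a standard deviation); .t() gives shape (num_data, feature_dim)
    x = dist.Normal(z, 5 * z + 3 * (1 - z)).sample([args.feature_dim]).t()
    # Treatment is more likely when z = 1
    t = dist.Bernoulli(0.75 * z + 0.25 * (1 - z)).sample()
    # Binary outcome depends on both z and t
    y = dist.Bernoulli(logits=3 * (z + 2 * (2 * t - 2))).sample()
    # True ITE: P(y=1 | z, t=1) - P(y=1 | z, t=0), broadcasting t over a (2, 1) column
    t0_t1 = torch.tensor([[0.0], [1.0]])
    y_t0, y_t1 = dist.Bernoulli(logits=3 * (z + 2 * (2 * t0_t1 - 2))).mean
    true_ite = y_t1 - y_t0
    return x, t, y, true_ite

def main(args):
    x_train, t_train, y_train, _ = generate_data(args)
    cevae = CEVAE(feature_dim=args.feature_dim, latent_dim=args.latent_dim,
                  hidden_dim=args.hidden_dim, num_layers=args.num_layers, num_samples=10)
    cevae.fit(x_train, t_train, y_train, num_epochs=args.num_epochs,
              batch_size=args.batch_size, learning_rate=args.learning_rate)
    # Evaluate on a fresh sample drawn from the same process
    x_test, t_test, y_test, true_ite = generate_data(args)
    true_ate = true_ite.mean()
    # Naive ATE: unadjusted difference in mean outcome between treated and control units
    naive_ate = y_test[t_test == 1].mean() - y_test[t_test == 0].mean()
    if args.jit:
        cevae = cevae.to_script_module()
    est_ite = cevae.ite(x_test)
    est_ate = est_ite.mean()
    print("true ATE = {:0.3g}".format(true_ate.item()))
    print("naive ATE = {:0.3g}".format(naive_ate.item()))
    print("estimated ATE = {:0.3g}".format(est_ate.item()))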
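To see what generate_data returns, and what the naive ATE computed in main measures, without installing Pyro, here is a rough stdlib mirror of the same process. It is a hypothetical sketch: it uses Python lists and random.Random in place of torch tensors and pyro.distributions, so its draws will not match the script's, but the shapes (num_data rows of feature_dim features) and the two simple estimators (mean true ITE, difference in group means) follow the code above.

```python
import math
import random


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def outcome_prob(z: int, t: int) -> float:
    # P(y = 1 | z, t) with logits = 3(z + 2(2t - 2))
    return sigmoid(3 * (z + 2 * (2 * t - 2)))


def generate_data_lists(num_data: int, feature_dim: int, seed: int = 0):
    """Hypothetical stdlib mirror of generate_data(): returns plain lists."""
    rng = random.Random(seed)
    x, t, y, true_ite = [], [], [], []
    for _ in range(num_data):
        z = 1 if rng.random() < 0.5 else 0          # z ~ Bernoulli(0.5)
        scale = 5 * z + 3 * (1 - z)                 # standard deviation, not variance
        x.append([rng.gauss(z, scale) for _ in range(feature_dim)])
        t_i = 1 if rng.random() < 0.75 * z + 0.25 * (1 - z) else 0
        y_i = 1 if rng.random() < outcome_prob(z, t_i) else 0
        t.append(t_i)
        y.append(y_i)
        # True ITE given z, as in the torch version
        true_ite.append(outcome_prob(z, 1) - outcome_prob(z, 0))
    return x, t, y, true_ite


def naive_ate(t, y) -> float:
    # Unadjusted difference in mean outcome between treated and control units
    treated = [yi for ti, yi in zip(t, y) if ti == 1]
    control = [yi for ti, yi in zip(t, y) if ti == 0]
    return sum(treated) / len(treated) - sum(control) / len(control)


x, t, y, true_ite = generate_data_lists(num_data=1000, feature_dim=5)
print("true ATE  =", round(sum(true_ite) / len(true_ite), 3))
print("naive ATE =", round(naive_ate(t, y), 3))
```

With a large sample the naive difference in means lands well above the mean true ITE; that gap is the confounding bias CEVAE is meant to correct.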

I/O Contract

| Parameter | Type | Description |
|---|---|---|
| --num-data | int | Number of data points in each generated dataset, train and test (default: 1000) |
| --feature-dim | int | Feature dimensionality (default: 5) |
| --latent-dim | int | Latent space dimensionality (default: 20) |
| --hidden-dim | int | Hidden layer size (default: 200) |
| --num-layers | int | Number of hidden layers (default: 3) |
| -n / --num-epochs | int | Training epochs (default: 50) |
| --jit | flag | Convert the trained model to a TorchScript module with to_script_module() before evaluation |
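The command-line interface in the table can be sketched with argparse as follows. Only the flags listed in the table are included; main() also reads args.batch_size and args.learning_rate, whose flag names the table does not give, so this parser alone would not be enough to run main(). The description string is a placeholder.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Flags and defaults copied from the I/O Contract table only
    parser = argparse.ArgumentParser(description="CEVAE synthetic example (sketch)")
    parser.add_argument("--num-data", default=1000, type=int)
    parser.add_argument("--feature-dim", default=5, type=int)
    parser.add_argument("--latent-dim", default=20, type=int)
    parser.add_argument("--hidden-dim", default=200, type=int)
    parser.add_argument("--num-layers", default=3, type=int)
    parser.add_argument("-n", "--num-epochs", default=50, type=int)
    # store_true: False unless --jit is passed
    parser.add_argument("--jit", action="store_true")
    return parser


# Parse an empty argument list to inspect the defaults
args = build_parser().parse_args([])
print(args.num_data, args.feature_dim, args.num_epochs, args.jit)
```

argparse turns dashes into underscores, which is why the code reads args.num_data, args.feature_dim, and so on.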

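Output (printed by the script):

  • True ATE: mean of the true ITEs computed from the data-generating process
  • Naive ATE: unadjusted difference in mean outcome between treated and control units
  • Estimated ATE: mean of the ITEs estimated by CEVAE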

Usage Examples

from pyro.contrib.cevae import CEVAE

# Create and train CEVAE (x_train: shape (N, feature_dim); t_train, y_train: shape (N,), e.g. from generate_data)
cevae = CEVAE(feature_dim=5, latent_dim=20, hidden_dim=200, num_layers=3, num_samples=10)
cevae.fit(x_train, t_train, y_train, num_epochs=50, batch_size=100, learning_rate=1e-3)

# Estimate treatment effects
est_ite = cevae.ite(x_test)  # Individual Treatment Effects
est_ate = est_ite.mean()       # Average Treatment Effect

# Optionally compile the trained model to TorchScript for evaluation
cevae_jit = cevae.to_script_module()
est_ite = cevae_jit.ite(x_test)
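The script only compares average effects, but because generate_data also returns a per-point true_ite, the ITE estimates can be scored directly too. The sketch below computes the absolute ATE error and the root PEHE (precision in estimation of heterogeneous effect, i.e. the root mean squared ITE error), two metrics commonly reported in the treatment-effect literature. Both helpers are hypothetical and are not part of the script or of pyro.contrib.cevae; they take plain Python sequences, so tensors would first be converted with .tolist().

```python
import math
from typing import Sequence


def _check(est_ite: Sequence[float], true_ite: Sequence[float]) -> None:
    if len(est_ite) != len(true_ite) or not est_ite:
        raise ValueError("expected two non-empty sequences of equal length")


def ate_error(est_ite: Sequence[float], true_ite: Sequence[float]) -> float:
    # Absolute difference between estimated and true average treatment effects
    _check(est_ite, true_ite)
    return abs(sum(est_ite) / len(est_ite) - sum(true_ite) / len(true_ite))


def sqrt_pehe(est_ite: Sequence[float], true_ite: Sequence[float]) -> float:
    # Root mean squared error between estimated and true per-point effects
    _check(est_ite, true_ite)
    sq = sum((e - t) ** 2 for e, t in zip(est_ite, true_ite))
    return math.sqrt(sq / len(est_ite))
```

For example, after the evaluation step one could call sqrt_pehe(est_ite.tolist(), true_ite.tolist()). A small ATE error can hide large per-point errors that cancel out, which root PEHE would expose.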
