Implementation:Pyro ppl Pyro CEVAE Synthetic
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/cevae/synthetic.py |
| Module | pyro.contrib.cevae |
| Pyro Features | pyro.contrib.cevae.CEVAE, causal inference, treatment effect estimation, variational autoencoder |
| Paper | Louizos et al. (2017), "Causal Effect Inference with Deep Latent-Variable Models" (NeurIPS 2017) |
Overview
This file demonstrates the Causal Effect Variational Autoencoder (CEVAE) for estimating individual and average treatment effects from observational data. It uses a synthetic data generation process in which a binary latent confounder z influences the observed features x, the treatment assignment t, and the outcome y.
The CEVAE model learns to:
- Infer the latent confounder z from the observed features x
- Estimate treatment effects by predicting outcomes under both treatment and control conditions
- Compute Individual Treatment Effects (ITE) and Average Treatment Effects (ATE)
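For reference, the standard interventional definitions of these two quantities are given below. In this synthetic process the outcome depends only on z and t, so substituting t = 1 and t = 0 into the outcome logits 3(z + 2(2t - 2)) gives the ITE conditional on z in closed form. The true_ite returned by generate_data is exactly this z-conditional quantity, whereas CEVAE has to estimate it from x alone.

```latex
\mathrm{ITE}(x) = \mathbb{E}[y \mid x, \mathrm{do}(t=1)] - \mathbb{E}[y \mid x, \mathrm{do}(t=0)],
\qquad \mathrm{ATE} = \mathbb{E}_x[\mathrm{ITE}(x)]

\text{Here: } \mathrm{ITE}(z) = \sigma(3z) - \sigma(3z - 12), \qquad \sigma(a) = \frac{1}{1 + e^{-a}}
```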
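The synthetic data generation process, applied independently to each of the num_data units, is:
- z ~ Bernoulli(0.5)
- x ~ Normal(z, 5z + 3(1-z)), where the second argument is the standard deviation, drawn independently for each of the feature_dim features
- t ~ Bernoulli(0.75z + 0.25(1-z))
- y ~ Bernoulli(logits=3(z + 2(2t - 2)))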
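Because z takes only two values, the population-level true ATE and the expected naive difference in means can be computed exactly from this process. The sketch below is an analytic check written for this page with plain PyTorch, not part of the Pyro example; `outcome_prob` is a hypothetical helper. Treated units are mostly z = 1 (P(z=1 | t=1) = 0.75), which is what biases the naive estimate upward.

```python
import torch


def outcome_prob(t, z):
    # P(y = 1 | t, z) for the synthetic process: logits = 3 * (z + 2 * (2t - 2))
    return torch.sigmoid(3 * (z + 2 * (2 * t - 2)))


z_vals = torch.tensor([0.0, 1.0])  # the two possible confounder values
p_z = torch.tensor([0.5, 0.5])     # z ~ Bernoulli(0.5)

# ITE conditional on z, then averaged over p(z) to get the true ATE
ite_given_z = outcome_prob(1.0, z_vals) - outcome_prob(0.0, z_vals)
true_ate = (p_z * ite_given_z).sum()

# Naive estimate E[y | t=1] - E[y | t=0]: weight each z by p(z | t) via Bayes' rule
p_t1_given_z = 0.75 * z_vals + 0.25 * (1 - z_vals)
p_z_given_t1 = p_z * p_t1_given_z / (p_z * p_t1_given_z).sum()
p_z_given_t0 = p_z * (1 - p_t1_given_z) / (p_z * (1 - p_t1_given_z)).sum()
naive_ate = (p_z_given_t1 * outcome_prob(1.0, z_vals)).sum() - (
    p_z_given_t0 * outcome_prob(0.0, z_vals)
).sum()

print(f"true ATE  = {true_ate.item():.4f}")   # true ATE  = 0.7262
print(f"naive ATE = {naive_ate.item():.4f}")  # naive ATE = 0.8394
```

Finite-sample values printed by the example script will scatter around these population numbers.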
Code Reference
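```python
import torch
import pyro.distributions as dist
from pyro.contrib.cevae import CEVAE


def generate_data(args):
    z = dist.Bernoulli(0.5).sample([args.num_data])  # latent confounder, shape [num_data]
    # Sample shape [feature_dim, num_data], then transpose to [num_data, feature_dim]
    x = dist.Normal(z, 5 * z + 3 * (1 - z)).sample([args.feature_dim]).t()
    t = dist.Bernoulli(0.75 * z + 0.25 * (1 - z)).sample()
    y = dist.Bernoulli(logits=3 * (z + 2 * (2 * t - 2))).sample()

    # Compute true ITE: t0_t1 has shape [2, 1] and broadcasts against z to [2, num_data]
    t0_t1 = torch.tensor([[0.0], [1.0]])
    y_t0, y_t1 = dist.Bernoulli(logits=3 * (z + 2 * (2 * t0_t1 - 2))).mean
    true_ite = y_t1 - y_t0
    return x, t, y, true_ite


def main(args):
    x_train, t_train, y_train, _ = generate_data(args)

    cevae = CEVAE(feature_dim=args.feature_dim, latent_dim=args.latent_dim,
                  hidden_dim=args.hidden_dim, num_layers=args.num_layers,
                  num_samples=10)
    cevae.fit(x_train, t_train, y_train, num_epochs=args.num_epochs,
              batch_size=args.batch_size, learning_rate=args.learning_rate)

    # Evaluate on a fresh draw from the same process
    x_test, t_test, y_test, true_ite = generate_data(args)
    true_ate = true_ite.mean()
    naive_ate = y_test[t_test == 1].mean() - y_test[t_test == 0].mean()
    if args.jit:
        cevae = cevae.to_script_module()
    est_ite = cevae.ite(x_test)
    est_ate = est_ite.mean()
```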
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| --num-data | int | Number of data points generated for each of the training and test sets (default: 1000) |
| --feature-dim | int | Feature dimensionality (default: 5) |
| --latent-dim | int | Latent space dimensionality (default: 20) |
| --hidden-dim | int | Hidden layer size (default: 200) |
| --num-layers | int | Number of hidden layers (default: 3) |
| -n / --num-epochs | int | Training epochs (default: 50) |
| --jit | flag | Convert the trained model to a script module (via to_script_module()) before evaluation |
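The flags in the table map onto a standard argparse parser. The following is a minimal sketch covering only the listed flags and defaults; `make_parser` is a hypothetical name, and the real script may define more options. In particular, main also reads args.batch_size and args.learning_rate, which the table does not list, so this sketch omits them rather than guess their flags and defaults.

```python
import argparse


def make_parser():
    # Hypothetical parser reproducing only the flags documented in the table
    parser = argparse.ArgumentParser(description="CEVAE on synthetic data")
    parser.add_argument("--num-data", default=1000, type=int)
    parser.add_argument("--feature-dim", default=5, type=int)
    parser.add_argument("--latent-dim", default=20, type=int)
    parser.add_argument("--hidden-dim", default=200, type=int)
    parser.add_argument("--num-layers", default=3, type=int)
    parser.add_argument("-n", "--num-epochs", default=50, type=int)
    parser.add_argument("--jit", action="store_true")  # off unless passed
    return parser


args = make_parser().parse_args([])  # empty argv: every option takes its default
print(args.num_data, args.latent_dim, args.num_epochs, args.jit)  # 1000 20 50 False
```

argparse converts dashes to underscores, which is why the code reads args.num_data for --num-data. Passing ["-n", "10", "--jit"] instead would yield num_epochs=10 and jit=True.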
Output:
- True ATE (from data generation process)
- Naive ATE (unadjusted difference in means)
- Estimated ATE (from CEVAE)
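The first two reported quantities need no trained model. Below is a minimal sketch computing them for one synthetic test set. It assumes torch.distributions as a stand-in for pyro.distributions (Pyro's distributions wrap these classes, so the sketch runs without Pyro installed) and uses a hypothetical generate_data(num_data, feature_dim) signature in place of the argparse args object.

```python
import torch
from torch.distributions import Bernoulli, Normal


def generate_data(num_data, feature_dim):
    # Mirrors the Pyro example's generate_data, written against torch.distributions
    z = Bernoulli(0.5).sample([num_data])
    x = Normal(z, 5 * z + 3 * (1 - z)).sample([feature_dim]).t()
    t = Bernoulli(0.75 * z + 0.25 * (1 - z)).sample()
    y = Bernoulli(logits=3 * (z + 2 * (2 * t - 2))).sample()
    t0_t1 = torch.tensor([[0.0], [1.0]])  # broadcasts with z to shape [2, num_data]
    y_t0, y_t1 = Bernoulli(logits=3 * (z + 2 * (2 * t0_t1 - 2))).mean
    return x, t, y, y_t1 - y_t0


torch.manual_seed(0)
x_test, t_test, y_test, true_ite = generate_data(num_data=1000, feature_dim=5)

# True ATE: average of the per-unit ITE known from the generating process
true_ate = true_ite.mean()
# Naive ATE: unadjusted difference in mean outcome between treated and control units
naive_ate = y_test[t_test == 1].mean() - y_test[t_test == 0].mean()

print(f"true ATE  = {true_ate.item():.3g}")
print(f"naive ATE = {naive_ate.item():.3g}")
```

On a draw of this size the naive estimate typically exceeds the true ATE, because treated units are disproportionately z = 1; the CEVAE estimate is meant to correct for that confounding.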
Usage Examples
```python
from pyro.contrib.cevae import CEVAE

# x_train, t_train, y_train, x_test: tensors as returned by generate_data above

# Create and train CEVAE
cevae = CEVAE(feature_dim=5, latent_dim=20, hidden_dim=200, num_layers=3, num_samples=10)
cevae.fit(x_train, t_train, y_train, num_epochs=50, batch_size=100, learning_rate=1e-3)

# Estimate treatment effects
est_ite = cevae.ite(x_test)  # Individual Treatment Effects
est_ate = est_ite.mean()     # Average Treatment Effect

# Optionally convert the fitted model to a script module for faster evaluation
cevae_jit = cevae.to_script_module()
est_ite = cevae_jit.ite(x_test)
```
Related Pages
- Pyro_ppl_Pyro_SparseGammaDEF - Another deep generative model example