Implementation:Pyro ppl Pyro SeedMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/seed_messenger.py |
| Module | pyro.poutine.seed_messenger |
| Lines | 39 |
| Parent Class | Messenger |
| Purpose | Set the random number generator seed for reproducible execution |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
SeedMessenger seeds the random number generator (RNG) with a pre-defined value before executing the wrapped function, and restores the original RNG state afterward. Apart from that restoration, this is equivalent to calling pyro.set_rng_seed(rng_seed) before the function call.
The handler saves the current RNG state on __enter__, sets the seed, and restores the saved state on __exit__. The enclosed computation is therefore reproducible, while code that runs after the handler exits draws from the same RNG stream it would have used had the handler never run.
Unlike most other messengers, SeedMessenger does not intercept individual pyro.sample or pyro.param calls. Instead, it operates purely at the context-manager level by manipulating the global RNG state. Note that its __enter__ does not call super().__enter__(), so it never pushes itself onto _PYRO_STACK.
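As a rough illustration of the save, seed, restore sequence, the sketch below reproduces it using only Python's standard random module. seeded_random is a hypothetical helper, not part of Pyro, and it manages a single generator, whereas Pyro's get_rng_state, set_rng_seed and set_rng_state helpers also cover PyTorch's generator.

```python
import random
from contextlib import contextmanager


@contextmanager
def seeded_random(rng_seed: int):
    # Hypothetical stdlib analogue of SeedMessenger (not a Pyro API).
    old_state = random.getstate()   # save, like get_rng_state()
    random.seed(rng_seed)           # reseed, like set_rng_seed(rng_seed)
    try:
        yield
    finally:
        random.setstate(old_state)  # restore, like set_rng_state(old_state)


random.seed(0)
with seeded_random(42):
    a = random.random()
with seeded_random(42):
    b = random.random()
after = random.random()  # first draw of the outer seed-0 stream

random.seed(0)
expected_after = random.random()

print(a == b, after == expected_after)  # True True
```

Both blocks produce the same draw, and the draw taken after them matches a fresh seed-0 stream, which is the "no lasting effect on the global RNG" property described above.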
Code Reference
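```python
from pyro.poutine.messenger import Messenger
from pyro.util import get_rng_state, set_rng_seed, set_rng_state


class SeedMessenger(Messenger):
    def __init__(self, rng_seed: int) -> None:
        assert isinstance(rng_seed, int)
        self.rng_seed = rng_seed
        super().__init__()

    def __enter__(self) -> None:
        # Save the current global RNG state, then reseed.
        self.old_state = get_rng_state()
        set_rng_seed(self.rng_seed)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore the RNG state captured on entry.
        set_rng_state(self.old_state)
```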
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| rng_seed | int | The random number generator seed to set; must be an int (enforced by an assert in __init__) |

| Effect | Description |
|---|---|
| On enter | Saves the current RNG state, then sets the RNG seed to rng_seed |
| On exit | Restores the previously saved RNG state |
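The "On exit" behavior also holds when the enclosed code raises: Python always calls __exit__ when leaving a with block, and because SeedMessenger.__exit__ returns None, the exception still propagates after the state has been restored. The sketch below checks this with a hypothetical SeedState class that copies the shape of SeedMessenger but uses the standard random module in place of Pyro's RNG helpers.

```python
import random


class SeedState:
    """Hypothetical stand-in mirroring SeedMessenger's enter/exit protocol."""

    def __init__(self, rng_seed: int) -> None:
        assert isinstance(rng_seed, int)
        self.rng_seed = rng_seed

    def __enter__(self) -> None:
        self.old_state = random.getstate()
        random.seed(self.rng_seed)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returning None lets any exception propagate,
        # but the saved state is restored first either way.
        random.setstate(self.old_state)


random.seed(7)
before = random.getstate()
try:
    with SeedState(123):
        random.random()
        raise RuntimeError("model failed")
except RuntimeError:
    pass

restored = random.getstate() == before
print(restored)  # True
```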
Note: This handler does not push itself onto _PYRO_STACK. It only affects the global RNG state.
Usage Examples
Reproducible Model Execution
```python
import pyro
import pyro.distributions as dist
import torch

def model():
    return pyro.sample("z", dist.Normal(0.0, 1.0))

seeded_model = pyro.poutine.seed(model, rng_seed=42)
result1 = seeded_model()
result2 = seeded_model()
assert torch.equal(result1, result2)  # same seed, same draw on every call
```
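Calling seeded_model twice returns the same value because the handler is re-entered, and therefore reseeds, on every call; as far as we understand Messenger's function wrapping, it amounts to running fn inside a with block on the handler. The sketch below imitates that with a hypothetical seed_fn wrapper over the standard random module; it is an analogue for illustration, not Pyro's implementation.

```python
import random
from functools import wraps


def seed_fn(fn, rng_seed: int):
    """Hypothetical analogue of pyro.poutine.seed(fn, rng_seed=...)."""
    @wraps(fn)
    def wrapped(*args, **kwargs):
        old_state = random.getstate()   # like SeedMessenger.__enter__
        random.seed(rng_seed)
        try:
            return fn(*args, **kwargs)
        finally:
            random.setstate(old_state)  # like SeedMessenger.__exit__
    return wrapped


def model():
    # Stand-in for pyro.sample("z", dist.Normal(0, 1)).
    return random.gauss(0.0, 1.0)


seeded_model = seed_fn(model, rng_seed=42)
result1 = seeded_model()
result2 = seeded_model()
print(result1 == result2)  # True
```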
Using as a Context Manager
```python
import pyro
import pyro.distributions as dist

with pyro.poutine.seed(rng_seed=123):
    z1 = pyro.sample("z1", dist.Normal(0.0, 1.0))
    z2 = pyro.sample("z2", dist.Normal(0.0, 1.0))
# The RNG state saved on entry is restored here
```
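Seeding does not make every sample in the block identical: z1 and z2 are successive draws from the reseeded stream, so they generally differ, but re-running the block with the same seed replays both. The sketch below shows this with the standard random module; seeded_random and draw_two are hypothetical names used only for illustration.

```python
import random
from contextlib import contextmanager


@contextmanager
def seeded_random(rng_seed: int):
    # Hypothetical stdlib stand-in for pyro.poutine.seed(rng_seed=...).
    old_state = random.getstate()
    random.seed(rng_seed)
    try:
        yield
    finally:
        random.setstate(old_state)


def draw_two(rng_seed: int):
    with seeded_random(rng_seed):
        z1 = random.gauss(0.0, 1.0)
        z2 = random.gauss(0.0, 1.0)
    return z1, z2


first = draw_two(123)
second = draw_two(123)
print(first == second)       # True: same seed replays the same pair
print(first[0] != first[1])  # True: z1 and z2 are distinct draws
```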
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.seed() factory function