Implementation:Online ml River Bandit BayesUCB
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Multi_Armed_Bandits, Bayesian_Methods |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
A Bayesian approach to the multi-armed bandit problem that uses posterior distributions to compute upper confidence bounds for arm selection.
Description
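BayesUCB maintains a Beta posterior distribution over the expected reward of each arm. At each time step, it computes an upper confidence bound (UCB) for every arm as the p-th quantile of that arm's Beta posterior, where p = 1 - 1/(n+1) and n is the number of rewards the policy has observed so far. The arm with the highest UCB is selected, and once its reward is observed, that arm's posterior is updated. Because p approaches 1 as n grows, the index tracks an increasingly extreme upper quantile: arms with wide (uncertain) posteriors keep being explored, while arms with concentrated posteriors are scored close to their posterior mean. This gives a principled Bayesian way to balance exploration and exploitation.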
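The select-then-update cycle described above can be sketched without any dependencies. This is a minimal illustration, not River's implementation: `TinyBayesUCB`, `beta_cdf`, `beta_quantile`, and `simulate` are hypothetical names, the Beta(1, 1) prior is an assumption, and the quantile is found by bisection on a binomial-sum identity (valid only for integer Beta parameters) rather than with scipy.

```python
import math
import random


def beta_cdf(x, alpha, beta):
    """CDF of Beta(alpha, beta) at x, valid for positive integer parameters.

    Uses the identity I_x(a, b) = P(Binomial(a + b - 1, x) >= a).
    Only suitable for a few hundred observations (larger counts overflow floats).
    """
    n = alpha + beta - 1
    return sum(
        math.comb(n, k) * x**k * (1 - x) ** (n - k) for k in range(alpha, n + 1)
    )


def beta_quantile(p, alpha, beta, tol=1e-9):
    """Invert the CDF by bisection on [0, 1]."""
    lo, hi = 0.0, 1.0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if beta_cdf(mid, alpha, beta) < p:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2


class TinyBayesUCB:
    """Hypothetical stand-in for illustration; not River's class."""

    def __init__(self, n_arms):
        # Assumed Beta(1, 1) prior per arm, stored as [alpha, beta].
        self.posteriors = [[1, 1] for _ in range(n_arms)]
        self.n = 0  # total number of rewards observed so far

    def index(self, arm):
        # Upper confidence bound: the p-th quantile of the arm's posterior.
        p = 1 - 1 / (self.n + 1)
        alpha, beta = self.posteriors[arm]
        return beta_quantile(p, alpha, beta)

    def pull(self):
        # max() keeps the first arm when indices tie.
        return max(range(len(self.posteriors)), key=self.index)

    def update(self, arm, reward):
        # Binary reward: a success increments alpha, a failure increments beta.
        self.posteriors[arm][0 if reward else 1] += 1
        self.n += 1


def simulate(probs, steps, seed=0):
    """Run the policy against Bernoulli arms with the given success rates."""
    rng = random.Random(seed)
    policy = TinyBayesUCB(len(probs))
    for _ in range(steps):
        arm = policy.pull()
        policy.update(arm, rng.random() < probs[arm])
    return policy
```

Calling `simulate([0.2, 0.8], 100)` returns the policy after 100 rounds; inspecting `policy.posteriors` shows how many successes and failures each arm has accumulated. The bisection is chosen only to keep the sketch dependency-free; a library quantile function would be the natural choice in practice.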
Usage
Use BayesUCB when you want a Bayesian bandit policy with theoretical guarantees. It is a natural fit for binary rewards, for which the Beta distribution is the conjugate posterior, and it performs well in practice. The algorithm requires scipy, which provides the inverse incomplete beta function used to compute the Beta quantile.
Code Reference
Source Location
- Repository: Online_ml_River
- File: river/bandit/bayes_ucb.py
Signature
class BayesUCB(bandit.base.Policy):
    def __init__(self, reward_obj=None, burn_in=0, seed: int | None = None):
        ...
Import
from river import bandit
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| reward_obj | RewardObj (optional) | Posterior distribution (defaults to proba.Beta()) |
| burn_in | int | Initial observations per arm before using posterior |
| seed | int (optional) | Random seed for reproducibility |
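The `burn_in` row can be read as a forced-exploration phase. The sketch below assumes that each arm must be pulled `burn_in` times before the posterior-based index is consulted; the helper `pull_with_burn_in` and its arguments are hypothetical, and the order in which River serves burn-in pulls may differ.

```python
from collections import Counter


def pull_with_burn_in(arm_ids, counts, burn_in, choose_by_index):
    """Return an under-sampled arm if one exists, else defer to the index rule."""
    for arm_id in arm_ids:
        if counts[arm_id] < burn_in:
            return arm_id
    return choose_by_index(arm_ids)


counts = Counter()
history = []
for _ in range(7):
    # Stand-in index rule that always prefers the last arm.
    arm = pull_with_burn_in(
        ["a", "b", "c"], counts, burn_in=2, choose_by_index=lambda ids: ids[-1]
    )
    counts[arm] += 1
    history.append(arm)

print(history)  # ['a', 'a', 'b', 'b', 'c', 'c', 'c']
```

With `burn_in=2` and three arms, the first six pulls cover every arm twice, and only the seventh pull is decided by the index rule.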
Usage Examples
import gymnasium as gym
from river import bandit
from river import stats

# Create the bandit environment and seed it for reproducibility.
env = gym.make('river_bandits/CandyCaneContest-v0')
_ = env.reset(seed=42)
_ = env.action_space.seed(123)

policy = bandit.BayesUCB(seed=123)

# Accumulate the total reward collected over the episode.
metric = stats.Sum()
while True:
    # Choose an arm, play it, then feed the reward back to the policy.
    action = policy.pull(range(env.action_space.n))
    observation, reward, terminated, truncated, info = env.step(action)
    policy.update(action, reward)
    metric.update(reward)
    if terminated or truncated:
        break

print(metric)  # Sum: 841.