Implementation: Online_ml_River / river/bandit/envs / CandyCaneContest
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Multi_Armed_Bandits, Reinforcement_Learning, Simulation |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
A Gymnasium environment based on the Kaggle Candy Cane Contest featuring 100 vending machines with decaying reward probabilities.
Description
CandyCaneContest is a multi-armed bandit environment with 100 arms (vending machines). Each machine has a threshold that determines the probability of receiving a reward. The unique characteristic is that each machine's threshold decays multiplicatively after each pull, making this a non-stationary bandit problem. The environment runs for 2000 steps by default and returns binary rewards. This models a real-world scenario where resources deplete over time.
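The decaying-threshold mechanic can be illustrated with a minimal, self-contained sketch. This is an assumption-laden toy model of the behaviour described above (a payout probability that shrinks multiplicatively after every pull), not code taken from river's implementation; the function name and starting threshold are illustrative.

```python
import random

# Toy model (assumed from the description, not river's actual code):
# a machine pays out with probability equal to its current threshold,
# and the threshold decays multiplicatively after every pull.
def simulate_machine(threshold=0.8, reward_decay=0.03, n_pulls=50, seed=42):
    rng = random.Random(seed)
    rewards = []
    for _ in range(n_pulls):
        rewards.append(1 if rng.random() < threshold else 0)
        threshold *= 1 - reward_decay  # multiplicative decay per pull
    return rewards, threshold

rewards, final_threshold = simulate_machine()
# After 50 pulls the threshold has shrunk to 0.8 * 0.97**50 ≈ 0.174
```

A heavily-pulled machine therefore becomes steadily less rewarding, which is what makes greedy exploitation of a single arm self-defeating in this environment.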
Usage
Use this environment to evaluate bandit algorithms on non-stationary problems where reward distributions change over time. It's particularly useful for testing algorithms that can adapt to changing reward structures, such as those with decay or forgetting mechanisms.
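One common forgetting mechanism is an exponentially weighted (constant step size) value estimate, which discounts old rewards so the estimate can track a drifting arm. The sketch below is illustrative only; the class name and API are hypothetical, not part of river.

```python
# Hypothetical recency-weighted value estimate for one arm: with a fixed
# step size alpha, old rewards are discounted exponentially, so the
# estimate tracks a non-stationary reward probability.
class EwmaArmEstimate:
    def __init__(self, alpha=0.5):
        self.alpha = alpha  # fixed step size => exponential forgetting
        self.value = 0.0

    def update(self, reward):
        # Q <- Q + alpha * (r - Q): recent rewards dominate older ones
        self.value += self.alpha * (reward - self.value)

est = EwmaArmEstimate(alpha=0.5)
for r in [1, 1, 0, 0]:
    est.update(r)
# est.value is now 0.1875: the two recent zeros outweigh the earlier ones
```

A plain sample-average estimate would report 0.5 here; the constant step size is what lets the policy notice that the arm has decayed.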
Code Reference
Source Location
- Repository: Online_ml_River
- File: river/bandit/envs/candy_cane.py
Signature
class CandyCaneContest(gym.Env):

    n_steps = 2000

    def __init__(self, n_machines=100, reward_decay=0.03):
        ...

    def reset(self, seed=None, options=None):
        ...

    def step(self, machine_index):
        ...
Import
import gymnasium as gym
from river import bandit  # importing river.bandit registers the river_bandits env ids

env = gym.make('river_bandits/CandyCaneContest-v0')
I/O Contract
| Parameter/Method | Type | Description |
|---|---|---|
| n_machines | int (default: 100) | Number of vending machines (arms) |
| reward_decay | float (default: 0.03) | Multiplicative decay rate per pull |
| action_space | Discrete(n_machines) | Machine indices; Discrete(100) with the default settings |
| reward_range | (0.0, 1.0) | Rewards are binary: 0.0 (no candy) or 1.0 |
Usage Examples
import gymnasium as gym
from river import bandit
from river import stats

env = gym.make('river_bandits/CandyCaneContest-v0')
_ = env.reset(seed=42)
_ = env.action_space.seed(123)

policy = bandit.EpsilonGreedy(epsilon=0.1, seed=101)
metric = stats.Sum()

while True:
    arm = policy.pull(range(env.action_space.n))
    observation, reward, terminated, truncated, info = env.step(arm)
    policy.update(arm, reward)
    metric.update(reward)
    if terminated or truncated:
        break

print(metric)  # Total reward collected
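For a dependency-free variant of the loop above, the same evaluation pattern can be sketched with the standard library only. Everything here is illustrative (the machine model, policy functions, and parameters are assumptions, not river or Gymnasium APIs); it shows the shape of a policy comparison on decaying machines.

```python
import random

# Stdlib-only sketch: run a pull policy against decaying machines and
# return the total reward, mirroring the env/policy loop above.
def run(policy, n_machines=10, reward_decay=0.03, n_steps=2000, seed=0):
    rng = random.Random(seed)
    thresholds = [rng.random() for _ in range(n_machines)]
    estimates = [0.0] * n_machines  # sample-average reward estimates
    counts = [0] * n_machines
    total = 0
    for _ in range(n_steps):
        arm = policy(rng, estimates, counts)
        reward = 1 if rng.random() < thresholds[arm] else 0
        thresholds[arm] *= 1 - reward_decay  # depletion after each pull
        counts[arm] += 1
        estimates[arm] += (reward - estimates[arm]) / counts[arm]
        total += reward
    return total

def eps_greedy(rng, estimates, counts, epsilon=0.1):
    # Explore with probability epsilon, otherwise exploit the best estimate.
    if rng.random() < epsilon:
        return rng.randrange(len(estimates))
    return max(range(len(estimates)), key=lambda a: estimates[a])

def uniform(rng, estimates, counts):
    return rng.randrange(len(estimates))

greedy_total = run(eps_greedy)
random_total = run(uniform)
```

Swapping in different policy functions makes it easy to see how exploration strategies cope with the depletion dynamic before moving to the full river/Gymnasium setup.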