Implementation:Online ml River Bandit Policy Base
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Multi_Armed_Bandits, Reinforcement_Learning |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
Base classes for bandit policies that provide abstract interfaces for both standard and contextual multi-armed bandit problems.
Description
The module defines two abstract base classes, Policy and ContextualPolicy. Policy implements the core multi-armed bandit interface: selecting an arm (pull), feeding back an observed reward (update), and an optional burn-in phase during which each arm is given a chance to be pulled before the policy's own selection rule takes over. Each policy keeps reward statistics for every arm and exposes a ranking of the arms based on them. ContextualPolicy extends Policy for contextual bandits, where arm selection depends on a context (a dict of features) that is passed to both pull and update. Both classes accept a configurable reward object (a metric, a statistic, or a distribution) and an optional reward scaler.
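To make the mechanics described above concrete, the following is a simplified, standalone sketch in plain Python, so it runs without river installed. It is not river's implementation. The class names PolicySketch, UniformPolicy and RunningMean, the running-mean reward object, the per-arm interpretation of burn_in, and the use of a plain callable as the reward scaler are all assumptions made for illustration.

```python
import abc
import random
from collections import defaultdict
from typing import Callable, Hashable, Optional


class RunningMean:
    """Hypothetical reward object: an incremental mean of observed rewards."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0

    def update(self, x: float) -> None:
        self.n += 1
        self.mean += (x - self.mean) / self.n

    def get(self) -> float:
        return self.mean


class PolicySketch(abc.ABC):
    """Hypothetical, simplified stand-in for a Policy base class (not river's code)."""

    def __init__(
        self,
        burn_in: int = 0,
        reward_scaler: Optional[Callable[[float], float]] = None,
        seed: Optional[int] = None,
    ) -> None:
        self.burn_in = burn_in
        self.reward_scaler = reward_scaler  # assumption: a plain callable
        self._rng = random.Random(seed)
        self._rewards: dict = defaultdict(RunningMean)  # one reward object per arm
        self._counts: dict = defaultdict(int)  # number of updates per arm
        self._n = 0  # total number of updates

    @abc.abstractmethod
    def _pull(self, arm_ids: list) -> Hashable:
        """Selection rule supplied by the subclass."""

    def pull(self, arm_ids: list) -> Hashable:
        # Burn-in (assumed per arm): arms updated fewer than burn_in times go first.
        short_list = [a for a in arm_ids if self._counts[a] < self.burn_in]
        if short_list:
            return self._rng.choice(short_list)
        return self._pull(arm_ids)

    def update(self, arm_id: Hashable, reward: float) -> None:
        if self.reward_scaler is not None:
            reward = self.reward_scaler(reward)
        self._counts[arm_id] += 1
        self._n += 1
        self._rewards[arm_id].update(reward)

    @property
    def ranking(self) -> list:
        # Arms with at least one update, highest estimated reward first.
        return sorted(self._rewards, key=lambda a: self._rewards[a].get(), reverse=True)


class UniformPolicy(PolicySketch):
    """Toy subclass: after burn-in, picks an arm uniformly at random."""

    def _pull(self, arm_ids: list) -> Hashable:
        return self._rng.choice(arm_ids)


policy = UniformPolicy(burn_in=2, seed=0)
arms = ["A", "B", "C"]
true_reward = {"A": 0.1, "B": 0.9, "C": 0.5}
for _ in range(6):  # 3 arms x burn_in=2: every pull here is a burn-in pull
    arm = policy.pull(arms)
    policy.update(arm, true_reward[arm])

print({a: policy._counts[a] for a in arms})  # {'A': 2, 'B': 2, 'C': 2}
print(policy.ranking)  # ['B', 'C', 'A']
```

The subclass only supplies `_pull`; burn-in, counting and reward tracking live in the base class, which is the division of labour this page attributes to Policy.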
Usage
Use these base classes when implementing custom bandit algorithms. Subclass Policy for standard bandits, or ContextualPolicy when decisions depend on context. As the usage example shows, a Policy subclass supplies its arm-selection rule by implementing _pull, while the base class takes care of the burn-in phase, per-arm reward tracking, and arm counts.
Code Reference
Source Location
- Repository: Online_ml_River
- File: river/bandit/base.py
Signature
```python
class Policy(base.Base, abc.ABC):
    def __init__(
        self,
        reward_obj: RewardObj | None = None,
        reward_scaler: compose.TargetTransformRegressor | None = None,
        burn_in=0,
    ):
        ...


class ContextualPolicy(Policy):
    def pull(self, arm_ids: list[ArmID], context: dict | None = None) -> ArmID:
        ...

    def update(self, arm_id, context, *reward_args, **reward_kwargs):
        ...
```
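The ContextualPolicy signature above threads a context dict through both pull and update. As a standalone illustration of that calling pattern (not river's implementation), here is a sketch of a tabular greedy contextual policy with a per-context burn-in. The class name ContextualGreedySketch, the `key` parameter, and the choice of bucketing contexts on a single categorical feature are all hypothetical.

```python
from collections import defaultdict
from typing import Hashable, Optional


class ContextualGreedySketch:
    """Hypothetical tabular contextual policy (not river's ContextualPolicy).

    Keeps one reward sum and count per (context bucket, arm) pair, where the
    bucket is the value of the single context feature named by `key`.
    """

    def __init__(self, key: str, burn_in: int = 1) -> None:
        if burn_in < 1:
            raise ValueError("burn_in must be >= 1 so every arm has an estimate")
        self.key = key
        self.burn_in = burn_in
        self._sums: dict = defaultdict(float)
        self._counts: dict = defaultdict(int)

    def _bucket(self, context: Optional[dict]) -> Hashable:
        return None if context is None else context.get(self.key)

    def pull(self, arm_ids: list, context: Optional[dict] = None) -> Hashable:
        b = self._bucket(context)
        # Per-context burn-in: try under-explored arms first, in list order.
        for arm in arm_ids:
            if self._counts[b, arm] < self.burn_in:
                return arm
        # Otherwise exploit: highest mean reward observed in this context.
        return max(arm_ids, key=lambda a: self._sums[b, a] / self._counts[b, a])

    def update(self, arm_id: Hashable, context: Optional[dict], reward: float) -> None:
        b = self._bucket(context)
        self._counts[b, arm_id] += 1
        self._sums[b, arm_id] += reward


policy = ContextualGreedySketch(key="device", burn_in=1)
arms = ["A", "B"]
# Arm A pays off on mobile, arm B on desktop.
payoff = {("mobile", "A"): 1.0, ("mobile", "B"): 0.0,
          ("desktop", "A"): 0.0, ("desktop", "B"): 1.0}
for device in ["mobile", "desktop"] * 3:
    ctx = {"device": device}
    arm = policy.pull(arms, ctx)
    policy.update(arm, ctx, payoff[device, arm])

print(policy.pull(arms, {"device": "mobile"}))   # A
print(policy.pull(arms, {"device": "desktop"}))  # B
```

The same arm set yields different choices depending on the context, which is what distinguishes a contextual policy from a plain one. Real contextual policies typically generalise across contexts with a model rather than a lookup table.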
Import
```python
from river import bandit
```
I/O Contract
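| Method | Input | Output |
|---|---|---|
| pull(arm_ids) | List of available arm IDs | ID of the selected arm |
| update(arm_id, *reward_args, **reward_kwargs) | Arm ID and reward arguments (e.g. a numeric reward) | None (updates internal state) |
| ranking (property) | None | List of arm IDs, best-performing first |
| ContextualPolicy.pull(arm_ids, context) | List of available arm IDs and a context dict | ID of the selected arm |
| ContextualPolicy.update(arm_id, context, *reward_args, **reward_kwargs) | Arm ID, context dict, and reward arguments | None (updates internal state) |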
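The I/O contract table above can also be expressed as a structural type. The sketch below captures it with typing.Protocol and adds a driver loop that relies only on those three members. BanditPolicyLike, run and the toy RoundRobin policy are hypothetical names, and update is narrowed to a single numeric reward for simplicity.

```python
from typing import Callable, Hashable, Protocol, runtime_checkable


@runtime_checkable
class BanditPolicyLike(Protocol):
    """Hypothetical structural type mirroring the I/O contract table."""

    def pull(self, arm_ids: list) -> Hashable: ...

    # Simplification: a single numeric reward instead of *reward_args/**reward_kwargs.
    def update(self, arm_id: Hashable, reward: float) -> None: ...

    @property
    def ranking(self) -> list: ...


def run(policy: BanditPolicyLike, arm_ids: list,
        reward_fn: Callable[[Hashable], float], n_steps: int) -> list:
    """Drive any contract-conforming policy: pull, observe reward, update."""
    for _ in range(n_steps):
        arm = policy.pull(arm_ids)
        policy.update(arm, reward_fn(arm))
    return policy.ranking


class RoundRobin:
    """Toy policy: cycles through arms and tracks total reward per arm."""

    def __init__(self) -> None:
        self._t = 0
        self._totals: dict = {}

    def pull(self, arm_ids: list) -> Hashable:
        arm = arm_ids[self._t % len(arm_ids)]
        self._t += 1
        return arm

    def update(self, arm_id: Hashable, reward: float) -> None:
        self._totals[arm_id] = self._totals.get(arm_id, 0.0) + reward

    @property
    def ranking(self) -> list:
        return sorted(self._totals, key=self._totals.get, reverse=True)


policy = RoundRobin()
print(isinstance(policy, BanditPolicyLike))  # True
print(run(policy, ["A", "B", "C"], {"A": 0.2, "B": 0.7, "C": 0.5}.get, n_steps=6))
# ['B', 'C', 'A']
```

Because run only touches pull, update and ranking, it does not care how a policy chooses arms internally, which is the kind of interchangeability a shared base class is meant to provide.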
Usage Examples
```python
import random

from river import bandit


# Subclass Policy and implement _pull to define the arm-selection rule
class MyBandit(bandit.base.Policy):
    def _pull(self, arm_ids):
        return random.choice(arm_ids)


# Initialize and use
policy = MyBandit(burn_in=10)
arms = ['A', 'B', 'C']

# Pull an arm
arm = policy.pull(arms)

# Update the pulled arm with the observed reward
policy.update(arm, 1.0)

# Get the ranking of the arms
print(policy.ranking)
```
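MyBandit above picks arms uniformly at random and never uses its reward statistics. A more typical `_pull` trades off exploration and exploitation. The following is a standalone sketch of an epsilon-greedy selection rule; the helper name epsilon_greedy_pull is hypothetical, and because this page does not describe how a river policy exposes its per-arm estimates, the helper takes them as an explicit dict.

```python
import random
from typing import Hashable


def epsilon_greedy_pull(arm_ids: list, means: dict, epsilon: float,
                        rng: random.Random) -> Hashable:
    """Hypothetical helper: explore with probability epsilon, else exploit.

    `means` maps arm IDs to current reward estimates; arms without an
    estimate default to +inf so they are tried before the others.
    """
    if rng.random() < epsilon:
        return rng.choice(arm_ids)  # explore: uniform random arm
    # Exploit: arm with the highest estimated reward (ties -> first in list).
    return max(arm_ids, key=lambda a: means.get(a, float("inf")))


rng = random.Random(0)
means = {"A": 0.2, "B": 0.7, "C": 0.5}
print(epsilon_greedy_pull(["A", "B", "C"], means, epsilon=0.0, rng=rng))  # B
print(epsilon_greedy_pull(["A", "D"], means, epsilon=0.0, rng=rng))  # D (no estimate yet)
```

With epsilon set to 0 the rule is purely greedy; larger values spend more pulls exploring arms whose estimates may be wrong.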