Implementation: Bitsandbytes GlobalOptimManager
| Sources | Repo: bitsandbytes |
|---|---|
| Domains | Optimization, Configuration |
Overview
A concrete tool, provided by the bitsandbytes library, for managing per-parameter optimizer configuration overrides. GlobalOptimManager is a singleton class that stores and applies per-parameter optimizer hyperparameter overrides, enabling mixed-precision optimization strategies (e.g., 8-bit optimizer states for most parameters, 32-bit for a few sensitive ones) within a single optimizer instance.
Description
GlobalOptimManager is a singleton that stores per-parameter optimizer configuration overrides. It acts as a global registry that the bitsandbytes optimizers consult during initialization to determine the final configuration for each parameter.
Usage workflow:
- `get_instance()`: Obtain the singleton instance. The constructor is disabled (it raises `RuntimeError`); only `get_instance()` can create or return the singleton.
- `register_parameters(params)`: Register model parameters. This maps parameter identities (`id(tensor)`) to `(group_index, param_index)` pairs. Must be called before creating the optimizer.
- `override_config(parameters, key, value)`: Set per-parameter overrides for any optimizer hyperparameter. Supports single key-value pairs or dictionaries of multiple overrides.
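A minimal sketch of that sequence (the two-layer model here is a placeholder; fuller examples appear under Usage Examples below):
import torch.nn as nn
import bitsandbytes as bnb

# Placeholder model: an embedding feeding a linear head
model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 10))

mng = bnb.optim.GlobalOptimManager.get_instance()       # 1. get the singleton
mng.register_parameters(model.parameters())             # 2. register before the optimizer exists
mng.override_config(model[0].weight, "optim_bits", 32)  # 3. per-parameter override

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)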
Internal state:
- `pid2config`: Dictionary mapping `id(parameter)` to configuration override dictionaries.
- `index2config`: Dictionary mapping `(group_index, param_index)` tuples to configuration override dictionaries. Populated during `register_parameters()`.
- `uses_config_override`: Boolean flag indicating whether any overrides have been registered.
- `module_weight_config_triple`: List of `(module, param_name, config)` tuples for module-level overrides, used by `register_module_override()`.
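A small sketch of how this state fills in; these attributes are implementation details rather than public API, so inspect them only for debugging:
import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()
p = torch.nn.Parameter(torch.zeros(64, 64))

mng.register_parameters([p])
mng.override_config(p, "optim_bits", 32)

print(mng.uses_config_override)  # True
print(mng.pid2config[id(p)])     # {'optim_bits': 32}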
When an optimizer calls `get_config(gindex, pindex, group)`, it checks `index2config` for any overrides matching the `(gindex, pindex)` pair and merges them into the default configuration.
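In pseudocode, the lookup-and-merge behaves roughly as follows (a sketch of the semantics described above, not the library's actual code; the default keys shown are illustrative assumptions):
def get_config_sketch(mng, gindex, pindex, group):
    # Start from the optimizer's defaults for this parameter group
    # (illustrative keys only)
    config = {"lr": group["lr"], "optim_bits": 8}
    # Per-parameter overrides take precedence over the defaults
    if (gindex, pindex) in mng.index2config:
        config.update(mng.index2config[(gindex, pindex)])
    return config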
Code Reference
| Source | File | Lines |
|---|---|---|
| bitsandbytes repo | bitsandbytes/optim/optimizer.py | `GlobalOptimManager`, L22-110 |
Class signature:
class GlobalOptimManager:
    @classmethod
    def get_instance(cls) -> "GlobalOptimManager":
        ...

    def register_parameters(self, params) -> None:
        ...

    def override_config(
        self, parameters, key=None, value=None, key_value_dict=None
    ) -> None:
        ...
Import:
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
I/O Contract
register_parameters
| Parameter | Type | Required | Description |
|---|---|---|---|
| params | torch.Tensor or list of torch.Tensor | Yes | Model parameters to register. Can also be parameter groups (a list of dicts with a "params" key). |
override_config
| Parameter | Type | Required | Description |
|---|---|---|---|
| parameters | torch.Tensor, torch.nn.Parameter, or list | Yes | The parameter(s) whose configuration to override |
| key | str | No | The hyperparameter name to override (e.g., "optim_bits", "lr") |
| value | Any | No | The value for the hyperparameter specified by key |
| key_value_dict | dict | No | A dictionary of multiple key-value pairs to override simultaneously |
Note: Either key/value or key_value_dict must be provided, not both.
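Both call forms, side by side (assuming mng and a registered parameter p as above):
# Single hyperparameter as a key/value pair
mng.override_config(p, key="optim_bits", value=32)

# Several hyperparameters at once via a dict
mng.override_config(p, key_value_dict={"optim_bits": 32, "percentile_clipping": 5})

# Supplying key/value together with key_value_dict is not allowed (see note above)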
Output
| Output | Description |
|---|---|
| Stored config overrides | Configuration overrides are stored internally and applied when the optimizer calls `get_config()` during state initialization and update steps |
Usage Examples
Override embedding layer to use 32-bit while rest uses 8-bit:
import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = MyModel()  # MyModel: stand-in for your torch.nn.Module with an embedding layer
mng.register_parameters(model.parameters())  # register while the model is still on CPU

# Override: the embedding weight gets 32-bit optimizer state
mng.override_config(model.embedding.weight, "optim_bits", 32)

model = model.cuda()
# All other parameters use 8-bit Adam; the embedding uses 32-bit
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
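Registering on CPU and then calling .cuda() works because parameters are tracked by id(): by default, moving a module to the GPU updates each Parameter's data in place rather than replacing the Parameter objects, so the registered identities stay valid.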
Override multiple hyperparameters at once:
mng.override_config(
model.lm_head.weight,
key_value_dict={
"optim_bits": 32,
"lr": 5e-4,
"percentile_clipping": 5,
},
)
Override multiple parameters:
# Pass a list of parameters
sensitive_params = [model.embedding.weight, model.lm_head.weight]
mng.override_config(sensitive_params, "optim_bits", 32)
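For module-level overrides there is also register_module_override(module, param_name, config), which appends to the module_weight_config_triple list noted earlier; bitsandbytes' StableEmbedding uses this mechanism to keep its weight in 32-bit optimizer state. For example:
# Queue a module-level override; it is resolved when the optimizer initializes
mng.register_module_override(model.embedding, "weight", {"optim_bits": 32})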