Implementation: bitsandbytes Optimizer8bit Step
| Sources | Repo: bitsandbytes, Paper: 8-bit Optimizers via Block-wise Quantization |
|---|---|
| Domains | Optimization, Memory_Management |
Overview
Concrete tool, provided by the bitsandbytes library, for performing optimization steps with 8-bit quantized optimizer states. `Optimizer8bit` is the base class for all bitsandbytes optimizers, implementing the core step loop, state-dictionary management for FSDP compatibility, and integration with the `GlobalOptimManager` for per-parameter configuration overrides.
Description
The `step()` method of `Optimizer8bit` orchestrates the optimization pipeline:
- Iterates over all parameter groups and their parameters.
- For each parameter with a gradient, lazily initializes optimizer state on the first step via `init_state()`.
- Calls `update_step()` per parameter to perform the actual optimization (a simplified sketch of this loop follows the list).
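A minimal sketch of this control flow, written as a method of a `torch.optim.Optimizer` subclass; this is illustrative only, and the exact argument names and bookkeeping in the library may differ:
import torch

@torch.no_grad()
def step(self, closure=None):
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()  # re-evaluate the model if a closure is given

    for gindex, group in enumerate(self.param_groups):
        for pindex, p in enumerate(group["params"]):
            if p.grad is None:
                continue  # skip parameters without gradients
            if len(self.state[p]) == 0:
                self.init_state(group, p, gindex, pindex)  # lazy init on first step
            self.update_step(group, p, gindex, pindex)     # quantized update
    return loss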
Two concrete subclasses implement the state initialization and update logic:
- `Optimizer2State` (for Adam-family optimizers): manages two state tensors (momentum and variance). On `update_step`, it (1) dequantizes 8-bit states to FP32, (2) calls the appropriate CUDA kernel (`optimizer_update_32bit`, `optimizer_update_8bit`, or `optimizer_update_8bit_blockwise`), and (3) re-quantizes states back to 8-bit.
- `Optimizer1State` (for SGD-family optimizers): manages a single state tensor (momentum) and follows the same dequantize-update-requantize pattern.
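The pattern is easy to see in plain PyTorch. The sketch below applies one SGD-with-momentum step to a single parameter, using a per-tensor absmax for brevity (the real kernels quantize block-wise on the GPU); all names here are illustrative and are not the bitsandbytes API:
import torch

def momentum_update_8bit(p, grad, codes, absmax, qmap, lr=1e-2, beta=0.9):
    # (1) Dequantize: look up each uint8 code in the 256-entry FP32 codebook,
    #     then rescale by the stored absmax (real kernels do this per block).
    m = qmap[codes.long()] * absmax
    # (2) Update in FP32, exactly as a 32-bit optimizer would.
    m.mul_(beta).add_(grad)
    p.add_(m, alpha=-lr)
    # (3) Re-quantize: normalize by the new absmax and snap each value to
    #     the nearest codebook entry.
    absmax = m.abs().max().clamp(min=1e-8)
    codes = (m / absmax).unsqueeze(-1).sub(qmap).abs().argmin(-1).to(torch.uint8)
    return codes, absmax
Keeping the arithmetic in FP32 between quantization steps is what preserves update quality; only the persistent state is stored in 8 bits.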
Paged memory support: When `is_paged=True`, optimizer state tensors are allocated as paged memory that can be evicted from GPU to CPU when a CUDA out-of-memory condition occurs. The `GlobalPageManager` coordinates paged tensor allocation and prefetching.
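Enabling paged state is a constructor choice; for example (assuming an existing `model` module and the paged optimizer variants exported by bitsandbytes):
import bitsandbytes as bnb

# Optimizer state is allocated as paged memory that the driver can evict
# to CPU under GPU memory pressure.
optimizer = bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-3)

# Equivalent via the flag from the class signature below:
# optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, is_paged=True)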
FSDP compatibility: The custom `state_dict()` method wraps quantization-specific tensors (`state1`, `state2`, `absmax1`, `absmax2`, `qmap1`, `qmap2`, etc.) in a nested dictionary keyed by `__bnb_optimizer_quant_state__`. This prevents FSDP's `full_optim_state_dict` from attempting to gather these tensors across ranks, which would fail due to shape mismatches. The `load_state_dict()` method reverses the wrapping.
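A sketch of the wrapping logic, assuming the quantization keys listed above; the helper names are hypothetical, not the library's internals:
QUANT_KEYS = ("state1", "state2", "absmax1", "absmax2", "qmap1", "qmap2")

def wrap_param_state(param_state):
    # Move quantization-specific tensors under a sentinel key so FSDP's
    # full_optim_state_dict skips them instead of trying to all-gather
    # tensors whose shapes differ from the parameter's shape.
    quant, rest = {}, {}
    for k, v in param_state.items():
        (quant if k in QUANT_KEYS else rest)[k] = v
    if quant:
        rest["__bnb_optimizer_quant_state__"] = quant
    return rest

def unwrap_param_state(param_state):
    # load_state_dict() reverses the wrapping.
    quant = param_state.pop("__bnb_optimizer_quant_state__", {})
    param_state.update(quant)
    return param_state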
Code Reference
| Source | Files | Lines |
|---|---|---|
| bitsandbytes repo | bitsandbytes/optim/optimizer.py | Optimizer8bit L113-335, Optimizer2State L384-625, Optimizer1State L628-841 |
Class signature:
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32, is_paged=False):
...
@torch.no_grad()
def step(self, closure=None):
...
def state_dict(self):
...
def load_state_dict(self, state_dict, move_to_device=True):
...
Import:
from bitsandbytes.optim.optimizer import Optimizer8bit
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| params | iterable | Yes | Model parameters to optimize (iterable of torch.Tensor or dicts defining parameter groups) |
| closure | Callable | No | A closure that reevaluates the model and returns the loss |
| optim_bits | int | No (default: 32) | Number of bits for optimizer state (8 or 32) |
| is_paged | bool | No (default: False) | Whether to use paged memory for GPU-to-CPU offload on OOM |
Outputs
| Output | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Loss value if a closure was provided, otherwise None |
| parameter updates | in-place | Parameters are updated in-place via the optimizer step |
| optimizer states | stored internally | States stored as quantized 8-bit tensors (state1, state2) with absmax scaling factors and qmap codebooks |
Usage Examples
import torch
import bitsandbytes as bnb
# Model setup (MyModel is a placeholder for your own nn.Module)
model = MyModel().cuda()
# Create an 8-bit optimizer (e.g., Adam8bit which inherits from Optimizer8bit)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
# Standard training loop -- drop-in replacement
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)
loss.backward()
optimizer.step() # Calls Optimizer8bit.step() internally
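Per-parameter configuration overrides via `GlobalOptimManager` (mentioned in the Overview) follow the pattern documented in the bitsandbytes README; `MyModel` and its `fc1` attribute are placeholders:
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = MyModel()
mng.register_parameters(model.parameters())  # register before moving to GPU
model = model.cuda()

optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)

# Keep this layer's optimizer state in full 32-bit precision
mng.override_config(model.fc1.weight, "optim_bits", 32)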
Saving and loading state (FSDP-compatible):
# Save
state = optimizer.state_dict()
torch.save(state, "optimizer_state.pt")
# Load
state = torch.load("optimizer_state.pt")
optimizer.load_state_dict(state)