
Implementation:Bitsandbytes foundation Bitsandbytes Optimizer8bit Step

From Leeroopedia


Sources: bitsandbytes repo; paper "8-bit Optimizers via Block-wise Quantization"
Domains: Optimization, Memory_Management

Overview

Concrete tool in the bitsandbytes library for performing optimization steps with 8-bit quantized optimizer states. Optimizer8bit is the base class for all bitsandbytes optimizers: it implements the core step loop, state-dictionary management for FSDP compatibility, and integration with the GlobalOptimManager for per-parameter configuration overrides.
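The per-parameter override mechanism can be sketched as a singleton registry. This is a minimal illustration: the class mirrors the GlobalOptimManager named above, but the method bodies here are hypothetical, not the library's actual implementation.

```python
# Hypothetical sketch of per-parameter configuration overrides; the class name
# mirrors GlobalOptimManager from the text, but this registry is illustrative.
class SketchGlobalOptimManager:
    _instance = None

    def __init__(self):
        self.overrides = {}  # id(param) -> {config key: value}

    @classmethod
    def get_instance(cls):
        # Singleton access, so every optimizer sees the same overrides
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def override_config(self, param, key, value):
        # Record an override for one specific parameter tensor
        self.overrides.setdefault(id(param), {})[key] = value

    def get_config(self, param, defaults):
        # Merge global defaults with any per-parameter overrides
        config = dict(defaults)
        config.update(self.overrides.get(id(param), {}))
        return config
```

For example, a single embedding parameter could be kept in 32-bit state while the rest of the model uses 8-bit, by overriding `optim_bits` for just that parameter.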

Description

The step() method of Optimizer8bit orchestrates the optimization pipeline:

  1. Iterates over all parameter groups and their parameters.
  2. For each parameter with a gradient, lazily initializes optimizer state on the first step via init_state().
  3. Calls update_step() per parameter to perform the actual optimization.
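The three steps above can be sketched as a minimal plain-Python optimizer. The class and method names mirror the text, but the bodies are an illustrative SGD-with-momentum in pure Python, not the real CUDA-backed implementation.

```python
# Hypothetical sketch of the step() orchestration; parameters are modeled as
# dicts with "data" and "grad" lists instead of torch.Tensor.
class SketchOptimizer8bit:
    def __init__(self, params, lr=0.1, momentum=0.9):
        self.param_groups = [{"params": list(params), "lr": lr, "momentum": momentum}]
        self.state = {}  # per-parameter state, created lazily

    def init_state(self, group, p):
        # Lazily allocate optimizer state the first time this parameter is seen
        self.state[id(p)] = {"step": 0, "buf": [0.0] * len(p["data"])}

    def update_step(self, group, p):
        # Perform the actual per-parameter update (momentum SGD here)
        st = self.state[id(p)]
        st["step"] += 1
        for i, g in enumerate(p["grad"]):
            st["buf"][i] = group["momentum"] * st["buf"][i] + g
            p["data"][i] -= group["lr"] * st["buf"][i]

    def step(self):
        for group in self.param_groups:        # 1. iterate parameter groups
            for p in group["params"]:
                if p["grad"] is None:          # skip parameters without gradients
                    continue
                if id(p) not in self.state:
                    self.init_state(group, p)  # 2. lazy state init on first step
                self.update_step(group, p)     # 3. per-parameter update
```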

Two concrete subclasses implement the state initialization and update logic:

  • Optimizer2State (for Adam-family optimizers): Manages two state tensors (momentum and variance). On update_step: (1) dequantizes 8-bit states to FP32, (2) calls the appropriate CUDA kernel (optimizer_update_32bit, optimizer_update_8bit, or optimizer_update_8bit_blockwise), (3) re-quantizes states back to 8-bit.
  • Optimizer1State (for SGD-family optimizers): Manages a single state tensor (momentum). Same dequantize-update-requantize pattern.
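The dequantize-update-requantize pattern shared by both subclasses can be sketched in plain Python with block-wise absmax quantization. The function names here are illustrative; the real kernels (optimizer_update_8bit_blockwise and friends) operate on CUDA tensors.

```python
# Illustrative pure-Python sketch of block-wise absmax quantization and the
# dequantize -> update -> requantize cycle for a single momentum state tensor.
def quantize_blockwise(values, block_size=4):
    """Map each block of floats to int8 codes plus one absmax scale per block."""
    codes, absmax = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        scale = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        absmax.append(scale)
        codes.extend(round(v / scale * 127) for v in block)
    return codes, absmax

def dequantize_blockwise(codes, absmax, block_size=4):
    """Recover approximate floats from int8 codes and per-block scales."""
    return [c / 127 * absmax[i // block_size] for i, c in enumerate(codes)]

def momentum_update_8bit(codes, absmax, grads, beta=0.9, block_size=4):
    # 1. dequantize the 8-bit state to full precision
    state = dequantize_blockwise(codes, absmax, block_size)
    # 2. apply the update in full precision
    state = [beta * s + g for s, g in zip(state, grads)]
    # 3. requantize the state back to 8 bits
    return quantize_blockwise(state, block_size)
```

Because the scale is taken per block rather than per tensor, a single large outlier only degrades the precision of its own block, which is the core idea of block-wise quantization.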

Paged memory support: When is_paged=True, optimizer state tensors are allocated as paged memory that can be offloaded from GPU to CPU on OOM conditions. The GlobalPageManager coordinates paged tensor allocation and prefetching.
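A toy model of this paging behavior is sketched below. The class and method names are hypothetical, and real paged tensors use CUDA unified memory rather than Python dicts; this only illustrates evict-on-OOM and prefetch.

```python
# Hypothetical sketch of paged allocation; SketchPageManager is illustrative,
# not the real GlobalPageManager API.
class SketchPageManager:
    def __init__(self, gpu_capacity):
        self.gpu_capacity = gpu_capacity  # bytes of simulated GPU memory
        self.gpu = {}                     # name -> size, resident on "GPU"
        self.cpu = {}                     # name -> size, paged out to "CPU"

    def allocate(self, name, size):
        # On (simulated) OOM, page resident state tensors out to CPU
        while self.gpu and sum(self.gpu.values()) + size > self.gpu_capacity:
            evicted, sz = self.gpu.popitem()
            self.cpu[evicted] = sz
        self.gpu[name] = size

    def prefetch(self, name):
        # Bring a paged-out tensor back before the optimizer step needs it
        if name in self.cpu:
            self.allocate(name, self.cpu.pop(name))
```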

FSDP compatibility: The custom state_dict() method wraps quantization-specific tensors (state1, state2, absmax1, absmax2, qmap1, qmap2, etc.) in a nested dictionary keyed by __bnb_optimizer_quant_state__. This prevents FSDP's full_optim_state_dict from attempting to gather these tensors across ranks (which would fail due to shape mismatches). The load_state_dict() method reverses this wrapping.
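The wrapping can be sketched as plain dictionary manipulation. The helper names below are hypothetical; only the __bnb_optimizer_quant_state__ key and the tensor names come from the format described above.

```python
# Illustrative sketch of nesting quantization-specific entries of one
# parameter's state under __bnb_optimizer_quant_state__, and reversing it.
QUANT_KEYS = {"state1", "state2", "absmax1", "absmax2", "qmap1", "qmap2"}

def wrap_param_state(param_state):
    """Nest quantization-specific tensors so FSDP will not try to gather them."""
    wrapped = {k: v for k, v in param_state.items() if k not in QUANT_KEYS}
    quant = {k: v for k, v in param_state.items() if k in QUANT_KEYS}
    if quant:
        wrapped["__bnb_optimizer_quant_state__"] = quant
    return wrapped

def unwrap_param_state(wrapped):
    """Reverse the wrapping when loading a state dict."""
    state = {k: v for k, v in wrapped.items() if k != "__bnb_optimizer_quant_state__"}
    state.update(wrapped.get("__bnb_optimizer_quant_state__", {}))
    return state
```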

Code Reference

Source file: bitsandbytes/optim/optimizer.py (bitsandbytes repo)
Lines: Optimizer8bit L113-335, Optimizer2State L384-625, Optimizer1State L628-841

Class signature:

class Optimizer8bit(torch.optim.Optimizer):
    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
        ...

    @torch.no_grad()
    def step(self, closure=None):
        ...

    def state_dict(self):
        ...

    def load_state_dict(self, state_dict, move_to_device=True):
        ...

Import:

from bitsandbytes.optim.optimizer import Optimizer8bit

I/O Contract

Inputs

Parameter | Type | Required | Description
params | iterable | Yes | Model parameters to optimize (iterable of torch.Tensor or dicts defining parameter groups)
closure | Callable | No | A closure that reevaluates the model and returns the loss
optim_bits | int | No (default: 32) | Number of bits for optimizer state (8 or 32)
is_paged | bool | No (default: False) | Whether to use paged memory for GPU-to-CPU offload on OOM

Outputs

Output | Type | Description
loss | Optional[torch.Tensor] | Loss value if a closure was provided, otherwise None
parameter updates | in-place | Parameters are updated in place via the optimizer step
optimizer states | stored internally | States stored as quantized 8-bit tensors (state1, state2) with absmax scaling factors and qmap codebooks

Usage Examples

import torch
import bitsandbytes as bnb

# Model setup
model = MyModel().cuda()

# Create an 8-bit optimizer (e.g., Adam8bit which inherits from Optimizer8bit)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# Standard training loop -- drop-in replacement
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()   # Calls Optimizer8bit.step() internally

Saving and loading state (FSDP-compatible):

# Save
state = optimizer.state_dict()
torch.save(state, "optimizer_state.pt")

# Load
state = torch.load("optimizer_state.pt")
optimizer.load_state_dict(state)
