Implementation:Bitsandbytes foundation Bitsandbytes Adam8bit

From Leeroopedia


Sources: Repo: bitsandbytes; Paper: 8-bit Optimizers via Block-wise Quantization
Domains: Optimization

Overview

Concrete tool for memory-efficient Adam optimization with 8-bit quantized optimizer states, provided by the bitsandbytes library. Adam8bit is a drop-in replacement for torch.optim.Adam that stores momentum and variance in 8-bit format, cutting optimizer state memory by approximately 75%.
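
The 75% figure follows from simple arithmetic: two float32 state tensors cost 8 bytes per parameter, while two uint8 tensors cost 2 bytes per parameter plus small per-block scaling overheads. A back-of-the-envelope sketch for a 1B-parameter model (the block size of 2048 and the overhead accounting are assumptions for illustration; exact values vary by bitsandbytes version):

n_params = 1_000_000_000
block_size = 2048

adam_fp32 = n_params * 2 * 4                # two float32 states: 8 GB
adam_8bit = n_params * 2 * 1                # two uint8 states:   2 GB
absmax = (n_params // block_size) * 2 * 4   # per-block float32 scaling factors

print(f"Adam  (32-bit states): {adam_fp32 / 1e9:.2f} GB")
print(f"Adam8bit states:       {(adam_8bit + absmax) / 1e9:.2f} GB")
print(f"Savings:               {1 - (adam_8bit + absmax) / adam_fp32:.1%}")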

Description

Adam8bit inherits from Optimizer2State which inherits from Optimizer8bit. The inheritance chain is:

torch.optim.Optimizer
    -> Optimizer8bit        (step loop, state_dict, FSDP compat)
        -> Optimizer2State  (two-state init/update: momentum + variance)
            -> Adam8bit     (hardcodes optim_bits=8, optimizer_name="adam")
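
This chain can be confirmed interactively; the snippet below assumes the class names are importable from bitsandbytes.optim.optimizer, matching the source file listed under Code Reference:

import torch
import bitsandbytes as bnb
from bitsandbytes.optim.optimizer import Optimizer2State, Optimizer8bit

# The method resolution order mirrors the chain above.
print(bnb.optim.Adam8bit.__mro__)
assert issubclass(bnb.optim.Adam8bit, Optimizer2State)
assert issubclass(Optimizer2State, Optimizer8bit)
assert issubclass(Optimizer8bit, torch.optim.Optimizer)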

Key implementation details:

  • Adam8bit hardcodes optim_bits=8 in its call to Optimizer2State.__init__, regardless of the optim_bits parameter in its own signature (which exists only for API compatibility and raises ValueError if set to anything other than the default 32).
  • The optimizer_name is set to "adam", which selects the Adam kernel in the underlying CUDA optimizer update functions.
  • amsgrad is not supported; passing amsgrad=True raises ValueError.
  • The Optimizer2State class handles two state tensors (momentum as state1 and variance as state2).
  • On the first step, states are initialized to zeros via init_state(). For 8-bit mode, this creates uint8 zero tensors with corresponding quantization maps and absmax scaling factors.
  • On subsequent steps, update_step() dispatches to one of three CUDA kernel paths depending on precision and block mode (see the sketch after this list):
    • 32-bit path: optimizer_update_32bit (for tensors below min_8bit_size)
    • 8-bit non-blockwise path: optimizer_update_8bit
    • 8-bit blockwise path: optimizer_update_8bit_blockwise (default and recommended)
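
A minimal sketch of that dispatch decision, for illustration only: the function name and structure below are invented, while the real logic lives in update_step() and the kernel functions named above.

def choose_kernel_path(numel: int, optim_bits: int, block_wise: bool,
                       min_8bit_size: int = 4096) -> str:
    # Small tensors and 32-bit mode fall back to the full-precision kernel.
    if optim_bits == 32 or numel < min_8bit_size:
        return "optimizer_update_32bit"
    # 8-bit mode: the block-wise kernel is the default and recommended path.
    if block_wise:
        return "optimizer_update_8bit_blockwise"
    return "optimizer_update_8bit"

print(choose_kernel_path(numel=1024, optim_bits=8, block_wise=True))     # 32-bit path
print(choose_kernel_path(numel=1 << 20, optim_bits=8, block_wise=True))  # blockwise path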

Code Reference

Source File Class Lines
bitsandbytes repo bitsandbytes/optim/adam.py Adam8bit L70-139
bitsandbytes repo bitsandbytes/optim/optimizer.py Optimizer2State L384-625

Class signature:

class Adam8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
    ):

Import:

import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

I/O Contract

Inputs

Parameter Type Required Default Description
params iterable Yes -- Model parameters to optimize
lr float No 1e-3 Learning rate
betas tuple(float, float) No (0.9, 0.999) Decay rates for first and second moment estimates
eps float No 1e-8 Epsilon for numerical stability in denominator
weight_decay float No 0 Weight decay (L2 penalty)
min_8bit_size int No 4096 Minimum tensor size (number of elements) for 8-bit quantization; smaller tensors use 32-bit
percentile_clipping int No 100 Percentile for gradient clipping (100 disables clipping)
block_wise bool No True Whether to use block-wise quantization (recommended)
is_paged bool No False Whether to use paged memory for GPU-to-CPU offload
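
A hedged example of setting the optional knobs above (assumes a model object exists; the values are illustrative, not recommendations):

import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=2e-4,
    min_8bit_size=16384,      # keep tensors below 16k elements in 32-bit state
    percentile_clipping=95,   # clip gradients at the 95th percentile of recent history
    block_wise=True,          # block-wise quantization (default)
)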

Outputs

Output Type Description
parameter updates in-place Model parameters are updated in-place
state1 torch.uint8 tensor Momentum (first moment), stored in 8-bit quantized format
state2 torch.uint8 tensor Variance (second moment), stored in 8-bit quantized format
absmax1 torch.float32 tensor Per-block scaling factors for momentum
absmax2 torch.float32 tensor Per-block scaling factors for variance
qmap1 torch.Tensor Signed dynamic quantization map (256 levels) for momentum
qmap2 torch.Tensor Unsigned dynamic quantization map (256 levels) for variance
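
A sketch that inspects these state tensors after one step. It assumes a CUDA device and a parameter with at least min_8bit_size elements so the 8-bit path is taken; the state key names follow the table above.

import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(8192, device="cuda"))
opt = bnb.optim.Adam8bit([p], lr=1e-3)

p.grad = torch.randn_like(p)
opt.step()  # first step initializes the quantized states

state = opt.state[p]
print(state["state1"].dtype)   # torch.uint8 (momentum)
print(state["state2"].dtype)   # torch.uint8 (variance)
print(state["absmax1"].shape)  # per-block scaling factors for momentum
print(state["qmap1"].numel())  # 256-level quantization map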

Usage Examples

Standard training with Adam8bit:

import torch
import bitsandbytes as bnb

model = MyModel().cuda()
optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(batch["input_ids"].cuda())
        loss = loss_fn(outputs, batch["labels"].cuda())
        loss.backward()
        optimizer.step()

With paged memory for large models:

optimizer = bnb.optim.PagedAdam8bit(
    model.parameters(),
    lr=1e-3,
)
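
Assuming PagedAdam8bit is the same optimizer with paged state enabled, an equivalent construction via the is_paged flag from the signature above would be (an assumption to verify against your bitsandbytes version):

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-3,
    is_paged=True,  # assumption: equivalent to PagedAdam8bit
)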
